function test_suite = test_naive_bayes_classifier_searchlight
% tests for cosmo_naive_bayes_classifier_searchlight
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
initTestSuite;
function test_naive_bayes_classifier_searchlight_tiny
ds = cosmo_synthetic_dataset('ntargets', 25, 'nchunks', 4, 'size', 'small');
nh = cosmo_spherical_neighborhood(ds, 'radius', 1, 'progress', false);
nh.neighbors = {1:size(ds.samples, 2)};
nh.fa = cosmo_slice(nh.fa, 1, 2, 'struct');
opt = struct();
opt.partitions = cosmo_nfold_partitioner(ds);
opt.output = 'winner_predictions';
opt.progress = false;
x = cosmo_naive_bayes_classifier_searchlight(ds, nh, opt);
assertEqual(x.fa, nh.fa);
assertEqual(x.a, nh.a);
nfolds = numel(opt.partitions.train_indices);
for fold = 1:nfolds
tr = opt.partitions.train_indices{fold};
te = opt.partitions.test_indices{fold};
pred = cosmo_classify_naive_bayes(ds.samples(tr, :), ...
ds.sa.targets(tr), ...
ds.samples(te, :));
y = x.samples(te, :);
assertEqual(y, pred);
end
% compare with standard searchlight
assert_same_output_as_classifical_searchlight(ds, nh, opt);
function test_naive_bayes_classifier_searchlight_basics
ds = cosmo_synthetic_dataset('ntargets', 25, 'nchunks', 4, 'size', 'small');
nh = cosmo_spherical_neighborhood(ds, 'radius', 1, 'progress', false);
nh.neighbors = {1:size(ds.samples, 2)};
nh.fa = cosmo_slice(nh.fa, 1, 2, 'struct');
opt = struct();
opt.partitions = cosmo_nfold_partitioner(ds);
opt.output = 'winner_predictions';
opt.progress = false;
x = cosmo_naive_bayes_classifier_searchlight(ds, nh, opt);
assertEqual(x.fa, nh.fa);
assertEqual(x.a, nh.a);
% compare with standard searchlight
assert_same_output_as_classifical_searchlight(ds, nh, opt);
opt.output = 'accuracy';
xacc = cosmo_naive_bayes_classifier_searchlight(ds, nh, opt);
assertEqual(xacc.samples, mean(bsxfun(@eq, x.samples, x.sa.targets)));
assert_same_output_as_classifical_searchlight(ds, nh, opt);
function test_naive_bayes_classifier_searchlight_multiple_pred
ds = cosmo_synthetic_dataset('ntargets', 25, 'nchunks', 5, 'size', 'small');
nh = cosmo_spherical_neighborhood(ds, 'radius', 1, 'progress', false);
opt = struct();
% multiple predictions per fold
opt.partitions = cosmo_nchoosek_partitioner(ds, 2);
opt.output = 'winner_predictions';
opt.progress = false;
x = cosmo_naive_bayes_classifier_searchlight(ds, nh, opt);
assertEqual(x.fa, nh.fa);
assertEqual(x.a, nh.a);
% compare with standard searchlight
assert_same_output_as_classifical_searchlight(ds, nh, opt);
opt.output = 'accuracy';
assert_same_output_as_classifical_searchlight(ds, nh, opt);
function assert_same_output_as_classifical_searchlight(ds, nh, opt)
opt.progress = false;
x = cosmo_naive_bayes_classifier_searchlight(ds, nh, opt);
opt.classifier = @cosmo_classify_naive_bayes;
y = cosmo_searchlight(ds, nh, @cosmo_crossvalidation_measure, opt);
sx = x.samples;
sy = y.samples;
assertElementsAlmostEqual(sx, sy);
x = rmfield(x, 'samples');
y = rmfield(y, 'samples');
assertEqual(x, y);
function test_naive_bayes_classifier_searchlight_exceptions
aet = @(varargin)assertExceptionThrown(@() ...
cosmo_naive_bayes_classifier_searchlight( ...
varargin{:}), '');
ds = cosmo_synthetic_dataset('size', 'small', 'nchunks', 4);
nh = cosmo_spherical_neighborhood(ds, 'radius', 1, 'progress', false);
opt = struct();
opt.progress = false;
opt.partitions = cosmo_nchoosek_partitioner(ds, 2);
opt.output = 'foo';
aet(ds, nh, opt);
% missing samples, so illegal partitions
opt.output = 'winner_predictions';
ds = cosmo_slice(ds, ds.sa.chunks <= 2);
opt.partitions = cosmo_nfold_partitioner(ds);
aet(ds, nh, opt);
% unsupported output
opt.output = 'fold_predictions';
ds_bad = ds;
ds_bad = cosmo_slice(ds_bad, ds_bad.sa.chunks <= 2);
opt.partitions = cosmo_nfold_partitioner(ds_bad);
aet(ds, nh, opt);
function test_naive_bayes_classifier_searchlight_deprecations
ds = cosmo_synthetic_dataset('size', 'small', 'nchunks', 4);
nh = cosmo_spherical_neighborhood(ds, 'radius', 1, 'progress', false);
%
opt = struct();
opt.progress = false;
opt.partitions = cosmo_nchoosek_partitioner(ds, 2);
opt.output = 'predictions';
orig_warning_state = cosmo_warning();
cleaner = onCleanup(@()cosmo_warning(orig_warning_state));
cosmo_warning('reset');
cosmo_warning('off');
% no warnings
w = cosmo_warning();
assertEqual(numel(w.shown_warnings), 0);
% output='predictions' is deprecated, so expect a warning
cosmo_naive_bayes_classifier_searchlight(ds, nh, opt);
w = cosmo_warning();
assertEqual(numel(w.shown_warnings), 1);
function test_naive_bayes_classifier_searchlight_partial_partitions
nchunks = 4;
ds = cosmo_synthetic_dataset('ntargets', 5, 'nchunks', nchunks, ...
'size', 'small');
nh = cosmo_spherical_neighborhood(ds, 'radius', 1, 'progress', false);
partitions = cosmo_nfold_partitioner(ds);
nsamples = size(ds.samples, 1);
prediction_count = zeros(nsamples, 1);
for k = 1:nchunks
with_missing = partitions.test_indices{k}(2:end);
partitions.test_indices{k} = with_missing;
prediction_count(with_missing) = prediction_count(with_missing) + 1;
end
opt = struct();
opt.progress = false;
opt.partitions = partitions;
opt.output = 'winner_predictions';
res = cosmo_naive_bayes_classifier_searchlight(ds, nh, opt);
ds_sa = rmfield(ds.sa, 'chunks');
assertEqual(res.sa, ds_sa);
cosmo_check_dataset(res);
assert_same_output_as_classifical_searchlight(ds, nh, opt);