function test_suite = test_crossvalidate
% tests for cosmo_crossvalidate
%
% Entry point of this test file: collects the local test functions
% defined below and hands over to initTestSuite.
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
% handles to the local functions in this file; the variable is not used
% explicitly here (presumably it is picked up by initTestSuite)
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
% provided by the unit test framework, not defined in this file;
% presumably it is what assigns the test_suite output - TODO confirm
initTestSuite;
function test_crossvalidate_basics
    % Compare the output of cosmo_crossvalidate with a reference result
    % computed by hand: for randomly drawn train/test splits, the
    % classifier is applied fold by fold and both the prediction matrix
    % and the overall accuracy must agree.
    classifier = @cosmo_classify_nn;

    % random integer in the range 6..10
    randint = @()ceil(rand() * 5 + 5);

    % dataset dimensions are drawn in the order targets, chunks, reps
    ntargets = randint();
    nchunks = randint();
    nreps = randint();
    ds = cosmo_synthetic_dataset('ntargets', ntargets, ...
                                 'nchunks', nchunks, ...
                                 'nreps', nreps, ...
                                 'seed', 0); % random data

    nsamples = size(ds.samples, 1);
    nfolds = randint();

    % between roughly a quarter and three quarters of the samples are
    % used for training in every fold
    train_size = ceil(nsamples * (rand() * .5 + .25));

    train_indices = cell(nfolds, 1);
    test_indices = cell(nfolds, 1);

    % one column per fold; NaN marks samples not tested in that fold
    expected_pred = NaN(nsamples, nfolds);

    for k = 1:nfolds
        shuffled = randperm(nsamples);
        tr = shuffled(1:train_size);
        te = shuffled((train_size + 1):end);

        train_indices{k} = tr;
        test_indices{k} = te;

        expected_pred(te, k) = classifier(ds.samples(tr, :), ...
                                          ds.sa.targets(tr), ...
                                          ds.samples(te, :));
    end

    partitions = struct();
    partitions.train_indices = train_indices;
    partitions.test_indices = test_indices;

    % accuracy is the fraction of correct predictions over all
    % predictions made across the folds
    has_pred = ~isnan(expected_pred);
    matches = bsxfun(@eq, ds.sa.targets, expected_pred) & has_pred;
    expected_acc = nnz(matches) / nnz(has_pred);

    % the partition check is switched off for these random splits
    opt = struct('check_partitions', false);

    [res_pred, res_acc] = cosmo_crossvalidate(ds, classifier, ...
                                              partitions, opt);

    assertEqual(res_pred, expected_pred);
    assertElementsAlmostEqual(res_acc, expected_acc);