function test_suite = test_crossvalidate
% tests for test_crossvalidate
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions=localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
initTestSuite;
function test_crossvalidate_basics
classifier=@cosmo_classify_nn;
randint=@()ceil(rand()*5+5);
ds=cosmo_synthetic_dataset('ntargets',randint(),...
'nchunks',randint(),...
'nreps',randint(),...
'seed',0); % random data
nsamples=size(ds.samples,1);
nfolds=randint();
partitions=struct();
partitions.train_indices=cell(nfolds,1);
partitions.test_indices=cell(nfolds,1);
train_size=ceil(nsamples*(rand()*.5+.25));
pred=NaN(nsamples,nfolds);
for fold=1:nfolds
all_idx=randperm(nsamples);
train_idx=all_idx(1:train_size);
test_idx=all_idx((train_size+1):end);
partitions.train_indices{fold}=train_idx;
partitions.test_indices{fold}=test_idx;
pred(test_idx,fold)=classifier(ds.samples(train_idx,:),...
ds.sa.targets(train_idx),...
ds.samples(test_idx,:));
end
pred_msk=~isnan(pred);
is_correct=bsxfun(@eq,ds.sa.targets,pred) & pred_msk;
acc=sum(is_correct(:))/sum(pred_msk(:));
opt=struct();
opt.check_partitions=false;
[res_pred,res_acc]=cosmo_crossvalidate(ds,classifier,partitions,opt);
assertEqual(res_pred,pred);
assertElementsAlmostEqual(res_acc,acc);