test crossvalidate

function test_suite = test_crossvalidate
% tests for cosmo_crossvalidate
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % 'test_functions' is not used directly in this function; it is
        % presumably read by the initTestSuite script below, which runs in
        % this function's workspace.
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % (localfunctions is not available there, so the call errors and
        % is deliberately ignored)
    end
    initTestSuite;


function test_crossvalidate_basics
% Check cosmo_crossvalidate against a manual fold-by-fold loop.
%
% Builds random (possibly overlapping across folds) train/test partitions,
% computes the nearest-neighbour predictions for each fold by hand, and
% requires cosmo_crossvalidate to return the same prediction matrix and
% the same accuracy over all predicted (non-NaN) entries.
    classify_func=@cosmo_classify_nn;
    rand_count=@()ceil(rand()*5+5); % random integer in 6..10

    % draw dataset dimensions in the same order as the arguments are used
    n_targets=rand_count();
    n_chunks=rand_count();
    n_reps=rand_count();
    ds=cosmo_synthetic_dataset('ntargets',n_targets,...
                               'nchunks',n_chunks,...
                               'nreps',n_reps,...
                               'seed',0);    % random data

    n_samples=size(ds.samples,1);
    n_folds=rand_count();

    % same training-set size in every fold: 25%..75% of all samples
    n_train=ceil(n_samples*(rand()*.5+.25));

    train_cell=cell(n_folds,1);
    test_cell=cell(n_folds,1);
    expected_pred=NaN(n_samples,n_folds);

    for k=1:n_folds
        perm=randperm(n_samples);
        tr=perm(1:n_train);
        te=perm((n_train+1):end);

        train_cell{k}=tr;
        test_cell{k}=te;

        train_samples=ds.samples(tr,:);
        train_targets=ds.sa.targets(tr);
        test_samples=ds.samples(te,:);
        expected_pred(te,k)=classify_func(train_samples,...
                                          train_targets,...
                                          test_samples);
    end

    % wrap the cell arrays in {} so struct() yields a 1x1 struct
    partitions=struct('train_indices',{train_cell},...
                      'test_indices',{test_cell});

    % accuracy over entries that received a prediction
    has_pred=~isnan(expected_pred);
    hits=bsxfun(@eq,ds.sa.targets,expected_pred) & has_pred;
    n_hits=sum(hits(:));
    n_predicted=sum(has_pred(:));
    expected_acc=n_hits/n_predicted;

    opt=struct('check_partitions',false);

    [res_pred,res_acc]=cosmo_crossvalidate(ds,classify_func,partitions,opt);
    assertEqual(res_pred,expected_pred);
    assertElementsAlmostEqual(res_acc,expected_acc);