% test crossvalidation measure

function test_suite = test_crossvalidation_measure
% tests for cosmo_crossvalidation_measure
%
% Entry point of this test file: it tries to collect handles to the local
% functions and then runs the initTestSuite script.
% NOTE(review): test_suite is not assigned explicitly here; presumably
% initTestSuite sets it in this workspace - confirm against the test
% framework in use.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % deliberately empty: failure of localfunctions() is ignored
    end
    % script call; kept after the assignment of test_functions above
    initTestSuite;


function test_crossvalidation_measure_regression
% regression test for cosmo_crossvalidation_measure
%
% Runs the measure on a synthetic dataset (6 targets, 4 chunks) and
% compares the output with hard-coded expected values for:
% - the default output and the explicit 'accuracy' output
% - 'winner_predictions', also when requested through the deprecated
%   'predictions' name
% - 'fold_accuracy'
% - a nearest-neighbor classifier, with and without the 'normalization'
%   and 'average_train_*' options
    ds=cosmo_synthetic_dataset('ntargets',6,'nchunks',4);
    % shift targets and chunks by a constant; the expected predictions
    % below are expressed with the same target offset (10+...)
    ds.sa.targets=ds.sa.targets+10;
    ds.sa.chunks=ds.sa.chunks+20;

    opt=struct();
    opt.partitions=cosmo_nfold_partitioner(ds);
    opt.classifier=@cosmo_classify_lda;

    % default output: a single accuracy value labeled 'accuracy'
    res=cosmo_crossvalidation_measure(ds,opt);
    assertElementsAlmostEqual(res.samples,0.6250);
    assertEqual(res.sa,cosmo_structjoin('labels',{'accuracy'}));

    % explicit 'accuracy' output must be identical to the default
    opt.output='accuracy';
    res2=cosmo_crossvalidation_measure(ds,opt);
    assertEqual(res,res2);

    % one predicted target per sample; sample attributes are those of
    % the input without the chunks
    opt.output='winner_predictions';
    res3=cosmo_crossvalidation_measure(ds,opt);
    assertEqual(res3.samples,10+[1 2 3 4 5 5 4 6 2 4 2 6 ...
                                1 2 3 4 6 6 1 3 3 4 3 1]');
    assertEqual(res3.sa,rmfield(ds.sa,'chunks'));

    % use deprecated output options
    % warnings are switched off for the remainder of this function; the
    % onCleanup object restores the previous warning state on exit
    warning_state=cosmo_warning();
    warning_state_resetter=onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');
    % 'predictions' is the deprecated name (it is listed as such in
    % test_crossvalidation_measure_deprecations); it must give the same
    % result as 'winner_predictions'
    opt.output='predictions';
    res3a=cosmo_crossvalidation_measure(ds,opt);
    assertEqual(res3,res3a);


    % one accuracy value per fold; the mean of the expected values
    % (15/24) matches the overall accuracy of 0.625 tested above
    opt.output='fold_accuracy';
    res4=cosmo_crossvalidation_measure(ds,opt);
    assertElementsAlmostEqual(res4.samples,[5 2 5 3]'/6);
    assertEqual(res4.sa.folds,(1:4)');

    % test different classifier
    opt.classifier=@cosmo_classify_nn;
    opt.partitions=cosmo_nfold_partitioner(ds);
    opt.output='winner_predictions';

    res6=cosmo_crossvalidation_measure(ds,opt);
    assertEqual(res6.samples,10+[1 2 3 1 5 6 4 6 5 4 6 6 6 ...
                                    2 3 4 2 5 1 2 3 4 3 1]');
    % test normalization option
    opt.normalization='zscore';
    res7=cosmo_crossvalidation_measure(ds,opt);
    assertEqual(res7.samples,10+[1 2 3 5 5 6 4 6 5 4 6 6 6 ...
                                    2 3 4 5 5 1 5 3 1 3 1]');

    % test with averaging samples
    % with a train count of 1 the expected predictions are the same as
    % those without averaging (res6)
    opt=rmfield(opt,'normalization');
    opt.average_train_count=1;
    res8=cosmo_crossvalidation_measure(ds,opt);
    assertEqual(res8.samples,10+[1 2 3 1 5 6 4 6 5 4 6 6 6 ...
                                    2 3 4 2 5 1 2 3 4 3 1]');

    opt.average_train_count=2;
    opt.average_train_resamplings=5;
    res9=cosmo_crossvalidation_measure(ds,opt);
    assertEqual(res9.samples,10+[1 2 3 4 5 6 4 6 2 4 6 6 1 ...
                                    2 3 4 5 6 1 2 3 4 5 1]');


function test_fold_accuracy()
% verifies the 'fold_accuracy' output of cosmo_crossvalidation_measure
%
% The output must have one row per fold, and each row must match the
% result obtained when the measure is run with only that fold.
    draw_count=@()ceil(rand()*5)+5;

    % random dataset dimensions (drawn in this order)
    n_classes=draw_count();
    n_chunks=draw_count();
    n_reps=draw_count();

    ds=cosmo_synthetic_dataset('ntargets',n_classes,...
                                'nchunks',n_chunks,...
                                'nreps',n_reps);

    % random data, with targets and chunks shifted by a constant
    ds.samples(:)=randn(size(ds.samples));
    ds.sa.targets=10+ds.sa.targets;
    ds.sa.chunks=20+ds.sa.chunks;

    all_folds=cosmo_nchoosek_partitioner(ds,3);
    fold_count=numel(all_folds.train_indices);

    opt=struct('partitions',all_folds,...
               'classifier',@cosmo_classify_nn,...
               'output','fold_accuracy');

    % result for all folds together
    joint=cosmo_crossvalidation_measure(ds,opt);
    expected_size=[fold_count,1];
    assertEqual(size(joint.samples),expected_size);
    assertEqual(size(joint.sa.folds),expected_size);

    % result for every fold on its own
    for fold_id=1:fold_count
        single_opt=opt;
        single_opt.partitions.train_indices=...
                            all_folds.train_indices(fold_id);
        single_opt.partitions.test_indices=...
                            all_folds.test_indices(fold_id);

        single=cosmo_crossvalidation_measure(ds,single_opt);
        assertElementsAlmostEqual(joint.samples(fold_id),single.samples);
    end



function test_fold_predictions
% verifies the 'fold_predictions' and 'winner_predictions' outputs of
% cosmo_crossvalidation_measure on a random dataset
%
% For every fold, the predictions from the measure are compared with
% those from calling the classifier directly and with the corresponding
% column of the cosmo_crossvalidate output. The accuracy returned by
% cosmo_crossvalidate and the winner predictions are verified as well.
    randint=@()ceil(rand()*4)+1;

    ntargets=randint();
    targets_offset=randint();
    nchunks=randint()+1;
    ds=cosmo_synthetic_dataset('ntargets',ntargets,...
                                'nchunks',nchunks,...
                                'nreps',randint());
    ds.samples(:)=randn(size(ds.samples));
    ds.sa.targets=ds.sa.targets+targets_offset;
    ds.sa.chunks=ds.sa.chunks+20;


    opt=struct();
    opt.partitions=cosmo_nchoosek_partitioner(ds,ceil(nchunks/2));
    opt.classifier=@cosmo_classify_nn;
    opt.output='fold_predictions';

    train_idx=opt.partitions.train_indices;
    test_idx=opt.partitions.test_indices;

    nfolds=numel(train_idx);
    nsamples=size(ds.samples,1);

    % using crossvalidation_measure
    res=cosmo_crossvalidation_measure(ds,opt);

    % using crossvalidate function
    [cv_pred,acc]=cosmo_crossvalidate(ds,opt.classifier,opt.partitions);

    % keeps track of which output rows were matched to some fold
    visited=false(size(res.samples));
    for k=1:nfolds
        % test crossvalidation_measure
        msk=res.sa.folds==k;
        visited(msk)=true;
        pred=res.samples(msk,:);
        assertEqual(size(pred),[numel(test_idx{k}),1]);
        assertEqual(res.sa.targets(msk),ds.sa.targets(test_idx{k}));

        % compare with classifier output
        fold_pred=opt.classifier(ds.samples(train_idx{k},:),...
                         ds.sa.targets(train_idx{k}),...
                         ds.samples(test_idx{k},:));
        assertEqual(fold_pred,pred);

        % check cosmo_crossvalidate output: column k must be NaN for
        % all samples that are not in the test set of fold k
        nan_msk=true(nsamples,1);
        nan_msk(test_idx{k})=false;
        assertEqual(isnan(cv_pred(:,k)),nan_msk);
        assertEqual(cv_pred(~nan_msk,k),fold_pred);
    end
    % every output row must belong to one of the folds
    assert(all(visited));

    % fields should only be targets and folds
    assertEqual(sort(fieldnames(res.sa)),...
                sort({'targets';'folds'}));

    % test accuracy
    % counts are summed over all elements; summing only along the first
    % dimension would give row vectors, for which '/' is matrix right
    % division rather than a ratio of counts
    pred_msk=~isnan(cv_pred);
    correct_pred=bsxfun(@eq,cv_pred,ds.sa.targets) & pred_msk;
    assertElementsAlmostEqual(acc,sum(correct_pred(:))/sum(pred_msk(:)));

    % test with winner_predictions
    opt.output='winner_predictions';
    res=cosmo_crossvalidation_measure(ds,opt);
    assertEqual(size(res.samples),[nsamples,1]);

    for row=1:nsamples
        % number of times each target was predicted for this sample
        % (NaN entries are not counted by histc)
        h=histc(cv_pred(row,:),(1:ntargets)+targets_offset);

        % predicted sample is a winner
        row_pred=res.samples(row)-targets_offset;
        assert(all(h<=h(row_pred)));

        % correct winner
        assert(h(row_pred)==max(h));
    end


function test_crossvalidation_measure_deprecations
% verifies that requesting a deprecated output type registers a warning
%
% The warning state is stored on entry and restored on exit.
    original_state=cosmo_warning();
    state_resetter=onCleanup(@()cosmo_warning(original_state));

    ds=cosmo_synthetic_dataset();
    opt=struct('classifier',@cosmo_classify_nn,...
               'partitions',cosmo_nfold_partitioner(ds));

    for output_cell={'predictions','raw'}
        % start each iteration with no registered warnings, and with
        % warnings switched off
        cosmo_warning('reset');
        cosmo_warning('off');

        opt.output=output_cell{1};

        % run the measure
        cosmo_crossvalidation_measure(ds,opt);

        % must have shown a warning
        current_state=cosmo_warning();
        shown=current_state.shown_warnings;
        assertTrue(~isempty(shown), 'no warning was shown');
    end



function test_crossvalidation_measure_exceptions
% verifies that cosmo_crossvalidation_measure raises an exception for
% illegal datasets, partitions and output types
    assert_throws=@(varargin)assertExceptionThrown(@()...
                        cosmo_crossvalidation_measure(varargin{:}),'');

    % empty struct as dataset, empty partitions, and @abs as classifier
    invalid_opt=struct('partitions',struct(),...
                       'classifier',@abs);
    assert_throws(struct,invalid_opt);

    ds=cosmo_synthetic_dataset();
    valid_opt=struct('partitions',cosmo_nfold_partitioner(ds),...
                     'classifier',@cosmo_classify_lda);

    % empty struct as dataset with otherwise proper options
    assert_throws(struct,valid_opt);

    % unknown output type
    invalid_opt=valid_opt;
    invalid_opt.output='foo';
    assert_throws(ds,invalid_opt);

    % output type that is not supported anymore
    invalid_opt=valid_opt;
    invalid_opt.output='accuracy_by_chunk';
    assert_throws(ds,invalid_opt);



function test_balanced_accuracy()
% verifies the 'balanced_accuracy' output of cosmo_crossvalidation_measure
%
% A dataset with random data and randomly assigned (unbalanced) targets
% and chunks is classified using a subset of the n-fold partitions. The
% balanced accuracy is compared with a value computed here from the
% per-class proportion of correct predictions; the fields of the
% 'accuracy' output and the 'winner_predictions' output are verified too.
    nclasses=10;
    nchunks=20;
    ds=cosmo_synthetic_dataset('ntargets',nclasses,...
                                'nchunks',nchunks,...
                                'nreps',4);

    % shuffle targets, use random data - assume data is unbalanced
    % afterwards
    ds.samples=randn(size(ds.samples));

    nsamples=size(ds.samples,1);
    while true
        ds.sa.targets=ceil(rand(nsamples,1)*nclasses);
        ds.sa.chunks=ceil(rand(nsamples,1)*nchunks);

        % number of samples for each class and for each chunk
        h_t=histc(ds.sa.targets,1:nclasses);
        h_c=histc(ds.sa.chunks,1:nchunks);

        % histc returns one count for each bin edge, so a missing class
        % or chunk shows up as a zero count (not as a shorter vector)
        if any(h_t==0) || any(h_c==0)
            % classes or chunks missing, regenerate
            continue;
        end

        % counts that are not all the same indicate imbalance
        if any(h_t~=h_t(1)) && any(h_c~=h_c(1))
            % imbalance
            break;
        end
    end

    % keep subset of all partitions, so that there are missing predictions
    % for some of the samples
    partitions=cosmo_nfold_partitioner(ds);
    nkeep=ceil(.3*nchunks);
    partitions.train_indices=partitions.train_indices(1:nkeep);
    partitions.test_indices=partitions.test_indices(1:nkeep);

    % compute balanced accuracy
    opt=struct();
    opt.classifier=@cosmo_classify_nn;
    opt.partitions=partitions;


    % as long as check_partitions is not set to false, an exception
    % should be thrown as the partitions are supposed to be unbalanced
    assertExceptionThrown(@()...
                        cosmo_check_partitions(partitions,ds),'');
    assertExceptionThrown(@()...
                        cosmo_crossvalidation_measure(ds,opt),'');
    opt.check_partitions=false;

    % compute accuracy
    opt.output='balanced_accuracy';
    ba_result=cosmo_crossvalidation_measure(ds,opt);

    opt.output='winner_predictions';
    pred_result=cosmo_crossvalidation_measure(ds,opt);

    opt.output='accuracy';
    acc_result=cosmo_crossvalidation_measure(ds,opt);

    % check fields
    result_cell={ba_result,acc_result};
    for k=1:numel(result_cell)
        result=result_cell{k};

        assertEqual(sort(fieldnames(result)),sort({'samples';'sa'}));
        assertEqual(fieldnames(result.sa),{'labels'});
    end

    assertEqual(ba_result.sa.labels,{'balanced_accuracy'});
    assertEqual(acc_result.sa.labels,{'accuracy'});


    % compute expected result for balanced accuracy
    [unused,unused,target_idx]=unique(ds.sa.targets);
    assert(max(target_idx)==nclasses);
    nfolds=numel(partitions.train_indices);

    % number of correct predictions for each fold and class, and number
    % of test samples for each class summed over folds
    correct_count=zeros(nfolds,nclasses);
    class_count=zeros(1,nclasses);

    % samples that are in no test set keep the value NaN
    all_pred=NaN(nsamples,1);

    for fold_i=1:nfolds
        tr_idx=partitions.train_indices{fold_i};
        te_idx=partitions.test_indices{fold_i};

        ds_tr=cosmo_slice(ds,tr_idx);
        ds_te=cosmo_slice(ds,te_idx);

        target_idx_te=target_idx(te_idx);

        pred=opt.classifier(ds_tr.samples,...
                            ds_tr.sa.targets,...
                            ds_te.samples);
        all_pred(te_idx)=pred;
        for class_i=1:nclasses
            target_msk=target_idx_te==class_i;
            is_correct=pred(target_msk)==ds_te.sa.targets(target_msk);
            correct_count(fold_i,class_i)=sum(is_correct);
            class_count(class_i)=class_count(class_i)+numel(is_correct);
        end
    end

    % proportion of correct predictions for each class
    class_acc=bsxfun(@rdivide,sum(correct_count,1),class_count);

    % verify expected result for balanced accuracy
    assertElementsAlmostEqual(mean(class_acc),ba_result.samples);

    % verify expected result for predictions of each class
    assertEqual(pred_result.samples,all_pred);
    assertEqual(pred_result.sa.targets,ds.sa.targets);



function test_pca()
% verifies the 'pca_explained_count' and 'pca_explained_ratio' options of
% cosmo_crossvalidation_measure, using a single train / test fold
%
% The comparison with manually computed predictions is done in
% helper_test_pca_count and helper_test_pca_ratio.
    n_classes=2;
    n_chunks=5;

    n_features=ceil(rand()*10+10);
    n_samples=n_classes*n_chunks*4*n_features;

    % zero-based sample index
    zero_based=(0:(n_samples-1))';

    % random data; targets alternate between samples, chunks change
    % every n_classes*n_chunks samples
    ds=struct();
    ds.samples=randn(n_samples,n_features);
    ds.sa.targets=1+mod(zero_based,n_classes);
    ds.sa.chunks=1+mod(floor(zero_based/(n_classes*n_chunks)),n_chunks);

    % single fold: test on the last chunk, train on all the others
    in_test_set=ds.sa.chunks==n_chunks;
    single_fold=struct();
    single_fold.train_indices={find(~in_test_set)};
    single_fold.test_indices={find(in_test_set)};

    base_opt=struct('partitions',single_fold,...
                    'classifier',@cosmo_classify_lda,...
                    'output','winner_predictions');

    % fixed number of components to retain
    counts=[1 ceil(n_features/2) n_features ceil(rand()*n_features)];
    for k=1:numel(counts)
        count_opt=base_opt;
        count_opt.pca_explained_count=counts(k);
        helper_test_pca_count(ds,count_opt,counts(k));
    end

    % ratio of explained variance to retain
    ratios=[.1 .5 .9 1 rand()];
    for k=1:numel(ratios)
        ratio_opt=base_opt;
        ratio_opt.pca_explained_ratio=ratios(k);
        helper_test_pca_ratio(ds,ratio_opt,ratios(k));
    end

function helper_test_pca_count(ds,opt,count)
% asserts that the predictions from cosmo_crossvalidation_measure for the
% test samples are equal to those computed manually with 'count' PCA
% components
    measure_result=cosmo_crossvalidation_measure(ds,opt);

    % predictions computed without the measure
    [manual_pred,test_rows]=helper_pca_crossval_single_fold(ds,opt,count);

    predicted_by_measure=measure_result.samples(test_rows);
    assertEqual(manual_pred,predicted_by_measure);

function [pred,test_indices]=helper_pca_crossval_single_fold(ds,opt,count)
% computes predictions for the single fold in opt.partitions after
% applying PCA estimated on the training samples
%
% Inputs:
%   ds              dataset struct
%   opt             struct with fields .partitions (with exactly one
%                   fold) and .classifier
%   count           second argument passed to cosmo_pca
%
% Outputs:
%   pred            output of opt.classifier for the test samples
%   test_indices    indices of the test samples in ds
    folds=opt.partitions;
    assert(numel(folds.train_indices)==1);
    assert(numel(folds.test_indices)==1);

    % estimate PCA on the training samples only
    train_ds=cosmo_slice(ds,folds.train_indices{1});
    [train_pca,pca_params]=cosmo_pca(train_ds.samples,count);

    % apply the training set's mean and coefficients to the test samples
    test_indices=folds.test_indices{1};
    test_ds=cosmo_slice(ds,test_indices);
    centered_test=bsxfun(@minus,test_ds.samples,pca_params.mu);
    test_pca=centered_test*pca_params.coef;

    pred=opt.classifier(train_pca,train_ds.sa.targets,test_pca);


function helper_test_pca_ratio(ds,opt,ratio)
% converts an explained-variance ratio to a number of components, based
% on a PCA of the training samples, and then runs helper_test_pca_count
    folds=opt.partitions;
    assert(numel(folds.train_indices)==1);

    train_ds=cosmo_slice(ds,folds.train_indices{1});
    [unused,pca_params]=cosmo_pca(train_ds.samples);

    % first component at which the cumulative explained value reaches
    % ratio*100; if there is none, all components are used
    cumulative=cumsum(pca_params.explained);
    threshold=ratio*100;
    count=find(cumulative>=threshold,1);
    if isempty(count)
        count=numel(pca_params.explained);
    end

    % delegate to count helper
    helper_test_pca_count(ds,opt,count)


function test_crossvalidation_measure_pca_exceptions
% verifies that cosmo_crossvalidation_measure raises an exception when
% both PCA options are set

    % helper: asserts that calling the measure with the given arguments
    % throws an exception
    aet=@(varargin)assertExceptionThrown(@()...
                cosmo_crossvalidation_measure(varargin{:}),'');
    ds=cosmo_synthetic_dataset();

    % base options with classifier and partitions, without PCA options
    opt=struct();
    opt.classifier=@cosmo_classify_lda;
    opt.partitions=cosmo_nfold_partitioner(ds);

    % mutually exclusive parameters
    bad_opt=opt;
    bad_opt.pca_explained_count=2;
    bad_opt.pca_explained_ratio=.5;

    aet(ds,bad_opt);