% test partitions

function test_suite = test_partitions
% tests for partitioning functions
%
% Entry point of this test file; the individual tests are the local
% test_* functions defined further down.
%
% Returns:
%   test_suite    NOTE(review): never assigned explicitly in this body;
%                 presumably set by the initTestSuite script provided by
%                 the unit test framework - confirm against the framework
%                 in use.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    % The variable test_functions is not used in this body; presumably
    % initTestSuite reads it from this workspace, so do not rename it.
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % deliberately best-effort: a failing localfunctions call is ignored
    initTestSuite;

function test_nfold_partitioner()
% Check cosmo_nfold_partitioner on a synthetic dataset with 5 chunks and
% 4 targets (20 samples).
%
% The test expects fold k to hold out samples (k-1)*4+(1:4) for testing
% and to train on all remaining samples, with indices stored as column
% vectors.
    nchunks=5;
    ntargets=4;
    nsamples=nchunks*ntargets;

    ds=cosmo_synthetic_dataset('nchunks',nchunks,'ntargets',ntargets);
    partitions=cosmo_nfold_partitioner(ds);

    % a dataset struct and its chunk vector must give identical partitions
    assertEqual(partitions, cosmo_nfold_partitioner(ds.sa.chunks));

    % both fields must be present, in this order
    field_names={'train_indices';'test_indices'};
    assertEqual(field_names, fieldnames(partitions));

    all_indices=1:nsamples;
    for fold=1:nchunks
        expected_test=(fold-1)*ntargets+(1:ntargets);
        expected_train=setdiff(all_indices,expected_test);

        train_folds=partitions.(field_names{1});
        assertEqual(train_folds{fold}, expected_train');

        test_folds=partitions.(field_names{2});
        assertEqual(test_folds{fold}, expected_test');
    end

function test_nchoosek_partitioner()
% Check cosmo_nchoosek_partitioner on a synthetic dataset with 5 chunks
% and 4 targets (20 samples):
% - one test chunk (count 1 or fraction .2) must equal n-fold partitions
% - count 3 and fraction .6 must agree, and differ from fraction .4
% - a chunk vector must be accepted instead of a dataset struct
% - with 3 test chunks there must be 10 folds, and every sample must
%   occur equally often across folds (see the counting loop below)
    ds=cosmo_synthetic_dataset('nchunks',5,'ntargets',4);

    nfold=cosmo_nfold_partitioner(ds);
    assertEqual(nfold,cosmo_nchoosek_partitioner(ds,1));
    assertEqual(nfold,cosmo_nchoosek_partitioner(ds,.2));

    by_count=cosmo_nchoosek_partitioner(ds,3);
    by_fraction=cosmo_nchoosek_partitioner(ds,.6);
    assertEqual(by_count,by_fraction);

    other_fraction=cosmo_nchoosek_partitioner(ds,.4);
    assertFalse(isequal(by_count, other_fraction));

    from_chunks=cosmo_nchoosek_partitioner(ds.sa.chunks,3);
    assertEqual(by_count,from_chunks);

    nsamples=20;
    field_names={'train_indices';'test_indices'};
    for j=1:numel(field_names)
        folds=by_count.(field_names{j});
        assertEqual(size(folds),[1 10]);

        % count in how many folds each sample index occurs
        counts=zeros(nsamples,1);
        for k=1:numel(folds)
            idx=folds{k};
            counts(idx)=counts(idx)+1;
        end

        % expected count is 4 for training (j==1) and 6 for testing (j==2)
        assertEqual(counts,ones(nsamples,1)*j*2+2);
    end

function test_nchoosek_partitioner_half()
% Check the 'half' option of cosmo_nchoosek_partitioner for 2, 6 and 10
% chunks, both with the original chunk values and with the chunk values
% shifted by a random integer offset in the range 20..29.
%
% With n=nchoosek(nchunks,nchunks/2) combinations, the test expects n/2
% folds, where fold k trains on the chunks in row k of the combination
% matrix and tests on the chunks in row n-k+1.
    random_offset=floor(rand()*10+20);

    for offset=[0 random_offset]
        for nchunks=2:4:10
            ds=cosmo_synthetic_dataset('nchunks',nchunks,'ntargets',3);
            ds.sa.chunks=ds.sa.chunks+offset;

            partitions=cosmo_nchoosek_partitioner(ds,'half');

            combis=nchoosek(1:nchunks,nchunks/2);
            ncombi=size(combis,1);
            nfolds=ncombi/2;

            assertEqual(numel(partitions.train_indices),nfolds);
            assertEqual(numel(partitions.test_indices),nfolds);

            for k=1:nfolds
                train_chunks=combis(k,:);
                test_chunks=combis(ncombi-k+1,:);

                % undo the offset so chunks can be compared with combis
                tr_idx=find(cosmo_match(ds.sa.chunks-offset,train_chunks));
                te_idx=find(cosmo_match(ds.sa.chunks-offset,test_chunks));

                assertEqual(partitions.train_indices{k},tr_idx);
                assertEqual(partitions.test_indices{k},te_idx);
            end
        end
    end



function test_nchoosek_partitioner_grouping()
% Check cosmo_nchoosek_partitioner with a grouping sample attribute.
%
% A 'modality' attribute with values 1 and 2 is derived from the targets.
% For each number of test chunks and each modality specification the test
% expects one fold per (chunk combination, test modality) pair, in which
%   - test samples are in the test chunks and have the test modality
%   - training samples are in neither the test chunks nor the test
%     modality
% Each such fold must occur exactly once. Passing the chunk and modality
% vectors directly must give the same result as passing the dataset and
% the attribute name.
    for nchunks=[2 5]
        ds=cosmo_synthetic_dataset('nchunks',nchunks,'ntargets',6);
        ds.sa.modality=mod(ds.sa.targets,2)+1;
        ds.sa.targets=floor(ds.sa.targets/2);

        for n_test=1:(nchunks-1)
            for moda_idx=1:4

                % modas: test modalities expected in the output
                % moda_arg: the corresponding argument for the partitioner
                switch moda_idx
                    case 3
                        modas={1 2};
                        moda_arg={1 2};
                    case 4
                        % empty argument; both modalities are expected
                        modas={1,2};
                        moda_arg=[];
                    otherwise
                        modas={moda_idx};
                        moda_arg=moda_idx;
                end

                n_moda=numel(modas);

                p=cosmo_nchoosek_partitioner(ds,n_test,'modality',...
                                                            moda_arg);
                combis=nchoosek(1:nchunks,n_test);
                n_combi=size(combis,1);

                n_folds=numel(p.train_indices);
                assertEqual(numel(p.test_indices),n_folds);
                assertEqual(n_folds,n_combi*n_moda);

                % each fold must be present exactly once
                visited_count=zeros(1,n_folds);
                for m=1:n_moda
                    for j=1:n_combi
                        in_test_chunks=cosmo_match(ds.sa.chunks,...
                                                        combis(j,:));
                        in_test_moda=cosmo_match(ds.sa.modality,...
                                                        modas{m});

                        tr_msk=~in_test_chunks & ~in_test_moda;
                        te_msk=in_test_chunks & in_test_moda;

                        % train and test part must sit at the same position
                        tr_pos=find_fold(p.train_indices,tr_msk);
                        te_pos=find_fold(p.test_indices,te_msk);
                        assertEqual(tr_pos,te_pos);

                        visited_count(tr_pos)=visited_count(tr_pos)+1;
                    end
                end

                assertEqual(visited_count,ones(1,n_folds));

                % also possible with indices
                p_from_vectors=cosmo_nchoosek_partitioner(...
                                            ds.sa.chunks,n_test,...
                                            ds.sa.modality,moda_arg);
                assertEqual(p,p_from_vectors);

            end
        end
    end

function pos=find_fold(folds, msk)
% Return the position of the fold that matches a mask.
%
% Inputs:
%   folds     cell array; each element holds the indices of one fold
%   msk       mask; its nonzero positions are the indices looked for
%
% Output:
%   pos       position k for which folds{k} has, ignoring order and
%             shape, the same elements as find(msk)
%
% An assertion fails if no fold matches or if more than one fold matches.
    wanted=find(msk);
    wanted=sort(wanted(:));

    pos=[];
    for k=1:numel(folds)
        candidate=sort(folds{k}(:));
        if ~isequal(candidate,wanted)
            continue;
        end

        % no duplicates
        assert(isempty(pos));
        pos=k;
    end
    assert(~isempty(pos));


function assert_disjoint(vs,i,j)
% Fail if vs(i) and vs(j) have an element in common; the failure message
% reports the first element of their intersection.
%
% NOTE(review): vs is indexed with parentheses, so this presumably expects
% a numeric array vs with index (or mask) arguments i and j - confirm
% against callers; none are visible in this part of the file.
    shared=intersect(vs(i),vs(j));
    if isempty(shared)
        return;
    end

    assertFalse(true,sprintf('element in common: %d', shared(1)));
function test_nchoosek_partitioner_exceptions()
% Check that cosmo_nchoosek_partitioner throws an exception for a range
% of invalid argument combinations, for datasets with 2 and 3 chunks.
    % helper: assert that calling the partitioner with the given arguments
    % throws (an empty identifier string is passed as expected exception)
    aet=@(varargin)assertExceptionThrown(@()...
                        cosmo_nchoosek_partitioner(varargin{:}),'');

    for nchunks=2:3
        ds=cosmo_synthetic_dataset('nchunks',nchunks,'ntargets',4);

        % each element is one argument list that must be rejected
        illegal_args={{ds,-1},...
                      {ds,0},...
                      {ds,1.01},...
                      {ds,[1 1]},...
                      {ds,.99},...
                      {struct,1},...
                      {ds,'foo'},...
                      {ds,.5,'foo'},...
                      {ds,struct},...
                      {ds,1,1,1},...
                      {ds.sa.chunks,1,'chunks',1},...
                      {ds.sa.chunks,1,'chunks',1,'chunks'}};

        for k=1:numel(illegal_args)
            args=illegal_args{k};
            aet(args{:});
        end

        ds.sa.modality=3; % size mismatch
        aet(ds,1,'modality',1);

    end