test independent samples partitioner

function test_suite = test_independent_samples_partitioner
% tests for cosmo_independent_samples_partitioner
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_independent_samples_partitioner_tiny()
    min_class_count=2;
    max_class_count=2+min_class_count;
    rng_class_count=[min_class_count,max_class_count];

    for nclasses=[2 3]
        for test_count=0:(nclasses+1)
            seed=ceil(rand()*1e6);
            opt=struct();
            opt.test_count=test_count;
            p1=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,opt,seed);

            opt=struct();
            opt.test_count=-test_count; % use test_ratio
            p2=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,opt,seed);

            assertEqual(p1,p2);
        end
    end


function test_independent_samples_partitioner_big()
    nclasses=ceil(rand()*4+10);
    for test_count=0:(nclasses+1)
        min_class_count=ceil(rand()*10+2);
        max_class_count=min_class_count+ceil(rand()*10+2);
        rng_class_count=[min_class_count,max_class_count];

        opt=struct();
        opt.test_count=test_count;
        opt.fold_count=ceil(rand()*100+10);

        p1=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,...
                                opt);

        opt=struct();
        opt.test_count=-test_count; % use test_ratio
        opt.fold_count=ceil(rand()*100+10);
        p2=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,...
                                opt);
    end

function test_independent_samples_partitioner_big_with_seed()
    opt=struct();
    opt.fold_count=100;
    opt.test_count=1;

    nclasses=ceil(rand()*4+10);
    min_class_count=ceil(rand()*10+2);
    max_class_count=min_class_count+ceil(rand()*10+2);
    rng_class_count=[min_class_count,max_class_count];

    ds_seed=(ceil(rand()*1e6));

    % without seed they use the default seed, must be equal
    p1=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,...
                                opt,ds_seed);
    p2=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,...
                                opt,ds_seed);
    assertEqual(p1,p2);
    opt.seed=1;
    p3=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,...
                                opt,ds_seed);
    assertEqual(p1,p3);
    opt.seed=2;
    p4=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,...
                                opt,ds_seed);
    assertFalse(isequal(p1,p4));

    % with no seed they must be identical
    opt.seed=0;
    p1=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,...
                                opt,ds_seed);
    p2=helper_test_independent_samples_partitioner(nclasses,...
                                rng_class_count,...
                                opt,ds_seed);

    assertFalse(isequal(p1,p2));




function p=helper_test_independent_samples_partitioner(nclasses,...
                            rng_class_count,opt,gen_seed)
    if nargin<4
        gen_seed=0;
    end


    ds=helper_generate_dataset(nclasses,...
                            rng_class_count,...
                            gen_seed);

    % compute how many samples per class in test set
    class_counts=histc(ds.sa.targets,1:nclasses);
    min_class_count=min(class_counts);
    max_class_count=max(class_counts);

    test_count=opt.test_count;
    if test_count<0
        % use test_ratio
        test_count=-test_count;
        opt=rmfield(opt,'test_count');
        opt.test_ratio=test_count/min_class_count;

    end

    train_count=min_class_count-test_count;

    has_illegal_count=train_count<=0 || test_count<=0;
    if has_illegal_count
        if ~isfield(opt,'fold_count')
            opt.fold_count=1;
        end
    else
        if isfield(opt,'fold_count')
            % given by calling funciton
            fold_count=opt.fold_count;
        else
            % use all folds available
            assert(nclasses<=3,'fold_count required with nclasses>3');
            assert(max_class_count<=5,'too many classes');
            % When not set we generate all possible folds; thereare at most
            % nchoosek(5,3)^nreps <= 20^3 = 8000 possible folds
            max_fold_count=8000; %

            % see how many possible folds based on each class
            combi_test=@(i)nchoosek(class_counts(i),test_count);
            test_fold_counts=arrayfun(combi_test,1:nclasses);
            combi_train=@(i)nchoosek(class_counts(i)-test_count,train_count);
            train_fold_counts=arrayfun(combi_train,1:nclasses);

            % take product
            fold_count=prod(test_fold_counts)*prod(train_fold_counts);
            assert(fold_count<=max_fold_count,'memory safety limit exceeded');
            opt.fold_count=fold_count;
        end
    end


    func=@()cosmo_independent_samples_partitioner(ds,opt);

    if has_illegal_count
        % not enough samples, expect an exception
        assertExceptionThrown(func,'')
        p=[];
        return
    end

    p=func();

    assertEqual(numel(p.train_indices),fold_count);
    assertEqual(numel(p.test_indices),fold_count);

    nsamples=size(ds.samples,1);

    train_count=min(class_counts)-test_count;

    for f=1:fold_count
        tr=p.train_indices{f};
        te=p.test_indices{f};
        assertTrue(isempty(intersect(tr,te)));

        assert_all_int_less_than(tr,nsamples);
        assert_all_int_less_than(te,nsamples);

        % equal number of targets in all classes, for train and test
        assert_all_hist_equal(ds.sa.targets(te),nclasses,test_count);
        assert_all_hist_equal(ds.sa.targets(tr),nclasses,train_count);
    end

    assert_all_folds_unique(p.train_indices,p.test_indices)


function ds=helper_generate_dataset(nclasses,rng_class_count,seed)
    % generate dataset where each target occurs at least min_class_count
    % times
    min_count=rng_class_count(1);
    max_count=rng_class_count(2);

    seed_arg={'seed',seed};

    delta=(max_count-min_count);
    assert(delta>0);

    % make lots of trials
    nregular=nclasses*min_count;
    nextra=ceil(delta*nclasses*cosmo_rand(seed_arg{:}));

    nsamples=nregular+nextra;

    ds=struct();
    ds.samples=rand(nsamples,2);
    ds.sa.targets=zeros(nsamples,1);
    rp=cosmo_randperm(nsamples,seed_arg{:});
    ds.sa.targets(1:nsamples)=mod(rp,nclasses)+1;
    ds.sa.chunks=(1:nsamples)';

    h=histc(ds.sa.targets,1:nclasses);
    assert(min(h)>=min_count)
    assert(max(h)<=max_count)



function assert_all_folds_unique(xs,ys)
    assert(all(min(cellfun(@min,xs))>=0));
    assert(all(min(cellfun(@min,ys))>=0));


    xs_max=max(cellfun(@max,xs));
    ys_max=max(cellfun(@max,ys));

    % value greater than max in all
    ys_mark=1+max(xs_max,ys_max);

    nfolds=numel(xs);
    merged=cell(nfolds,1);
    for f_i=1:nfolds

        xy=sort([xs{f_i}(:); (ys_mark+ys{f_i}(:))]);
        merged{f_i}=xy(:)';
    end

    % must all be same length
    c=cellfun(@numel,merged);
    assertEqual(c(1)+zeros(size(c)),c,'inputs do not have same length');

    % put in matrix
    merged_mat=cat(1,merged{:});
    s=sortrows(merged_mat);

    % look for duplicate rows
    eq_msk=bsxfun(@eq,s(1:(end-1),:),s(2:end,:));
    row_same=find(all(eq_msk,2),1);

    assertEqual(row_same,zeros(0,1),'row duplicate');



function assert_all_hist_equal(targets,nclasses,nreps)
    h_targets=histc(targets(:)',1:nclasses);
    assertEqual(h_targets,nreps+zeros(1,nclasses));


function assert_all_int_less_than(x,mx)
    assert(isnumeric(x));
    assertEqual(sort(x),x);
    assert(all(x>=1));
    assert(all(x<=mx));
    assert(all(round(x)==x));




function test_independent_samples_partitioner_mismatch_exceptions
    aet=@(varargin)assertExceptionThrown(@()...
                   cosmo_independent_samples_partitioner(varargin{:}),'');

    aet_targets=@(counts,varargin)aet(...
                 helper_generate_dataset_with_target_counts(counts{:}),...
                                                            varargin{:});


    opt=struct();
    opt.test_count=1;
    opt.fold_count=1;

    % missing target
    aet_targets({[3 3 3],[4 4]},opt);

    % not enought targets in one class
    aet_targets({[3 3 3],[2 2 1]},opt);


    opt.test_count=3;
    aet_targets({[4 3 4],[4 4 4]},opt);

    % try with ratio
    opt=rmfield(opt,'test_count');

    % missing target
    opt.test_ratio=.25;
    aet_targets({[10 10 10],[10 10]},opt);

    % too few targets
    aet_targets({[2 2 2],[4 4 4]},opt);

    % with too many folds
    aet_targets({[4 4 4],[4 4 4]},opt,'max_fold_count',0);



function test_independent_samples_partitioner_arg_exceptions
    aet=@(varargin)assertExceptionThrown(@()...
                   cosmo_independent_samples_partitioner(varargin{:}),'');
    ds=cosmo_synthetic_dataset();
    ds.sa.chunks=(1:size(ds.samples,1))';

    % missing arguments
    aet(ds);
    aet(ds,'fold_count');
    aet(ds,'fold_count',2);

    % mutually exclusive arguments
    aet(ds,'fold_count',2,'test_count',1,'test_ratio',.5);

    % not a dataset
    aet(struct,'fold_count',2,'test_count',1);

    % non-unique
    ds_bad=ds;
    ds_bad.sa.chunks(2)=ds.sa.chunks(1);
    aet(ds_bad,'fold_count',2,'test_count',1);


function ds=helper_generate_dataset_with_target_counts(varargin)
    nfeatures=2;
    nchunks=numel(varargin);
    ds_cell=cell(nchunks,1);
    for i=1:nchunks
        counts=varargin{i};
        nclasses=numel(counts);
        ds_parts=cell(nclasses,1);
        for j=1:nclasses
            nt=counts(j);
            ds=struct();

            ds.samples=randn(nt,nfeatures);
            ds.sa.targets=zeros(nt,1)+j;
            ds.sa.chunks=zeros(nt,1)+i;
            ds_parts{j}=ds;
        end

        ds_cell{i}=cosmo_stack(ds_parts);
    end

    ds=cosmo_stack(ds_cell);