% test_balance_dataset: unit tests for cosmo_balance_dataset

function test_suite = test_balance_dataset
% tests for cosmo_balance_dataset
%
% Entry point of the test file: returns a test suite containing the
% local test_* functions defined below.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % initTestSuite is provided by the test framework (not this file);
    % presumably it uses test_functions when set and assigns test_suite
    initTestSuite;


function r=randint()
% return a pseudo-random integer between 10 and 20, used to pick random
% counts (number of classes, repetitions) and label increments in the tests
    offset=10;
    span=10;
    r=ceil(offset+span*rand());


function test_balance_dataset_basics
% Build a dataset with a random number of classes, each having a random
% number of samples, and verify that cosmo_balance_dataset returns:
% - the sorted unique class labels,
% - an index matrix of size [min(nreps), nclasses] consistent with the
%   returned dataset,
% - for every class exactly min(nreps) samples, none selected twice.
    nclasses=randint();

    ds_cell=cell(nclasses,1);
    nreps=zeros(nclasses,1);
    classes=zeros(nclasses,1);

    % class labels grow by a positive amount each iteration, so they are
    % unique and sorted in ascending order
    class_id=0;
    for k=1:nclasses
        nrep=randint();
        class_id=class_id+randint();

        nreps(k)=nrep;
        classes(k)=class_id;
        ds_k=cosmo_synthetic_dataset('nchunks',1,...
                                    'ntargets',1,...
                                    'nreps',nrep,...
                                    'seed',0);
        ds_k.sa.targets(:)=class_id;

        ds_cell{k}=ds_k;
    end

    ds=cosmo_stack(ds_cell);
    nsamples=numel(ds.sa.chunks);

    % every sample gets its own chunk
    ds.sa.chunks(:)=1:nsamples;

    % tag every sample with a unique value in its first feature, so that
    % selecting the same sample twice can be detected below. This must
    % modify ds.samples (not a sample attribute), because the uniqueness
    % check reads balanced_ds.samples.
    ds.samples(:,1)=1:nsamples;

    % shuffle the samples so that classes are interleaved
    [unused,perm]=sort(randn(nsamples,1));
    ds=cosmo_slice(ds,perm);

    [balanced_ds,idxs,balanced_classes]=cosmo_balance_dataset(ds);

    assertEqual(unique(classes),balanced_classes);
    assertEqual(cosmo_slice(ds,idxs(:),1),balanced_ds);
    assertEqual(size(idxs),[min(nreps),nclasses]);

    for k=1:nclasses
        msk=balanced_ds.sa.targets==classes(k);

        % correct number of selected samples
        assertEqual(sum(msk),min(nreps));

        % all selected samples are different
        assertEqual(unique(balanced_ds.samples(msk,1)),...
                        sort(balanced_ds.samples(msk,1)));
    end






function test_balance_dataset_exceptions
% Verify that cosmo_balance_dataset raises an exception for invalid input:
% a non-dataset struct, non-unique chunks, and missing samples, targets,
% or chunks.
    aet=@(varargin)assertExceptionThrown(@()...
                    cosmo_balance_dataset(varargin{:}),'');
    ds=cosmo_synthetic_dataset('ntargets',randint());
    nsamples=size(ds.samples,1);
    ds.sa.chunks(:)=1:nsamples;

    % this should be ok
    cosmo_balance_dataset(ds);

    % not a dataset
    aet(struct);

    % not all chunks unique raises an exception
    bad_ds=ds;
    bad_ds.sa.chunks(1)=bad_ds.sa.chunks(2);
    aet(bad_ds);

    % missing samples
    bad_ds=rmfield(ds,'samples');
    aet(bad_ds);

    % missing targets
    bad_ds=ds;
    bad_ds.sa=rmfield(bad_ds.sa,'targets');
    aet(bad_ds);

    % missing chunks (previously removed 'targets' again by mistake)
    bad_ds=ds;
    bad_ds.sa=rmfield(bad_ds.sa,'chunks');
    aet(bad_ds);