test balance dataset

function test_suite = test_balance_dataset
    % tests for cosmo_balance_dataset
    %
    % Entry point of this test file; the individual test cases are the
    % local functions defined below.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % NOTE(review): initTestSuite is provided by the unit-test framework and
    % presumably reads 'test_functions' and assigns 'test_suite' in this
    % workspace - confirm against the framework documentation.
    initTestSuite;

function r = randint()
    % Helper: draw a random whole number (stored as double) by scaling and
    % shifting a single rand() draw and rounding up.
    lower_offset = 10;
    spread = 10;
    r = ceil(lower_offset + spread * rand());

function test_balance_dataset_basics
    % Basic functionality test for cosmo_balance_dataset.
    %
    % Builds a dataset with a random number of classes, each with a random
    % number of repeats, shuffles the samples, balances the dataset, and
    % then verifies that:
    %   - the returned classes are the unique target values,
    %   - the returned indices select exactly the balanced dataset,
    %   - every class contributes min(nreps) distinct samples.
    nclasses = randint();

    ds_cell = cell(nclasses, 1);
    nreps = zeros(nclasses, 1);
    classes = zeros(nclasses, 1);

    % class identifiers are strictly increasing (each step adds a positive
    % random integer), so they are unique and already sorted
    class_id = 0;
    for k = 1:nclasses
        nrep = randint();
        class_id = class_id + randint();

        nreps(k) = nrep;
        classes(k) = class_id;
        ds_k = cosmo_synthetic_dataset('nchunks', 1, ...
                                       'ntargets', 1, ...
                                       'nreps', nrep, ...
                                       'seed', 0);
        ds_k.sa.targets(:) = class_id;

        ds_cell{k} = ds_k;
    end

    ds = cosmo_stack(ds_cell);
    nsamples = numel(ds.sa.chunks);

    % every sample gets its own chunk value
    ds.sa.chunks(:) = 1:nsamples;

    % tag every sample with a unique value in the first feature, so that
    % the distinctness check below can tell samples apart; this must be
    % written to ds.samples (the data matrix), not to a sample attribute
    ds.samples(:, 1) = 1:nsamples;

    % randomly permute the sample order
    [unused, i] = sort(randn(nsamples, 1));
    ds = cosmo_slice(ds, i);

    [balanced_ds, idxs, balanced_classes] = cosmo_balance_dataset(ds);

    assertEqual(unique(classes), balanced_classes);
    assertEqual(cosmo_slice(ds, idxs(:), 1), balanced_ds);
    assertEqual(size(idxs), [min(nreps), nclasses]);

    for k = 1:nclasses
        msk = balanced_ds.sa.targets == classes(k);

        % correct number of selected samples
        assertEqual(sum(msk), min(nreps));

        % all selected samples are different
        assertEqual(unique(balanced_ds.samples(msk, 1)), ...
                    sort(balanced_ds.samples(msk, 1)));
    end

function test_balance_dataset_exceptions
    % Verifies that cosmo_balance_dataset raises an exception for invalid
    % input: a non-dataset struct, non-unique chunks, and datasets missing
    % .samples, .sa.targets or .sa.chunks.

    % helper: assert that calling cosmo_balance_dataset with the given
    % arguments throws an exception (empty identifier is passed through
    % to assertExceptionThrown)
    aet = @(varargin)assertExceptionThrown(@() ...
                                           cosmo_balance_dataset(varargin{:}), '');
    ds = cosmo_synthetic_dataset('ntargets', randint());
    nsamples = size(ds.samples, 1);
    ds.sa.chunks(:) = 1:nsamples;

    % this should be ok
    cosmo_balance_dataset(ds);

    % not a dataset
    aet(struct);

    % not all chunks unique raises an exception
    bad_ds = ds;
    bad_ds.sa.chunks(1) = bad_ds.sa.chunks(2);
    aet(bad_ds);

    % missing samples
    bad_ds = rmfield(ds, 'samples');
    aet(bad_ds);

    % missing targets
    bad_ds = ds;
    bad_ds.sa = rmfield(bad_ds.sa, 'targets');
    aet(bad_ds);

    % missing chunks
    bad_ds = ds;
    bad_ds.sa = rmfield(bad_ds.sa, 'chunks');
    aet(bad_ds);