% test balance partitions

function test_suite = test_balance_partitions
    % tests for cosmo_balance_partitions
    %
    % Entry point of this test file. The output test_suite is not assigned
    % by any statement visible here; presumably the initTestSuite call at
    % the end sets it - TODO confirm against the test framework in use.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % store the result of localfunctions() under the name
        % test_functions; NOTE(review): initTestSuite presumably looks this
        % variable up by name, so do not rename it without checking
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % deliberately empty: a failing localfunctions() call is ignored
    end
    initTestSuite;

function test_balance_partitions_repeats
    % Runs cosmo_balance_partitions with every combination of the
    % 'balance_test' and 'nrepeats' options (each one either given or left
    % out), passing the options once as a struct and once as a cell with
    % alternating names and values.
    %
    % For each combination it checks that:
    % - the number of output folds is nrepeats times the number of chunks
    % - each balanced training fold holds, for every class, as many
    %   samples as the rarest class has in the matching original fold
    % - the generic partition checks in assert_partitions_ok and
    %   assert_balanced_partitions_subset pass
    nchunks = 5;
    nsamples = 200;
    nclasses = 4;
    [p, ds] = get_sample_data(nsamples, nchunks, nclasses);

    opt_cell = {{ {'balance_test', false}, ...
                 {'balance_test', true}, ...
                 {} ...
                }, ...
                { {'nrepeats', 1}, ...
                 {'nrepeats', 5}, ...
                 {} ...
                }, ...
                { {'opt_as_struct', false}, ...
                 {'opt_as_struct', true} ...
                }};
    opt_prod = cosmo_cartprod(opt_cell);

    % option values this test expects when an option is not passed
    defaults = struct();
    defaults.nrepeats = 1;
    defaults.balance_test = true;

    for opt_i = 1:size(opt_prod, 1)
        opt = opt_prod(opt_i, :);
        opt_struct = cosmo_structjoin(opt);

        % 'opt_as_struct' only steers this test; it must not be passed on
        opt_as_struct = opt_struct.opt_as_struct;
        opt_struct = rmfield(opt_struct, 'opt_as_struct');
        if opt_as_struct
            args = opt_struct;
        else
            % 2xN cell; in column-major order: name1, value1, name2, ...
            args = [fieldnames(opt_struct) struct2cell(opt_struct)]';
        end

        b = cosmo_balance_partitions(p, ds, args);

        full_args = cosmo_structjoin(defaults, args);
        nrep = full_args.nrepeats;
        balance_test = full_args.balance_test;

        assertEqual(numel(b.train_indices), nrep * nchunks);
        assertEqual(numel(b.test_indices), nrep * nchunks);
        assertEqual(fieldnames(b), {'train_indices'; 'test_indices'});

        nfolds = numel(p.test_indices);
        for unbal_fold_i = 1:nfolds
            unbal_train = p.train_indices{unbal_fold_i};
            unbal_targets = ds.sa.targets(unbal_train);

            % size of the rarest class in the original training fold;
            % it does not depend on the repeat, so compute it only once
            min_class_count = min(histc(unbal_targets, 1:nclasses));

            for rep_i = 1:nrep
                % the repeats of one original fold are checked at
                % consecutive positions in the balanced output
                fold_i = (unbal_fold_i - 1) * nrep + rep_i;
                bal_train = b.train_indices{fold_i};
                bal_targets = ds.sa.targets(bal_train);
                h = histc(bal_targets, 1:nclasses)';
                assertTrue(all(min_class_count == h));
            end
        end

        assert_partitions_ok(ds, b, balance_test);
        assert_balanced_partitions_subset(p, b);
    end

function assert_balanced_partitions_subset(unbal_partitions, bal_partitions)
    % Verifies that every fold in bal_partitions is contained in at least
    % one fold of unbal_partitions (both its training and its test
    % indices), and that every fold of unbal_partitions is matched by at
    % least one fold of bal_partitions.
    train_max = cellfun(@max, unbal_partitions.train_indices);
    test_max = cellfun(@max, unbal_partitions.test_indices);
    nsamples = max([train_max test_max]);
    n_unbal = numel(unbal_partitions.train_indices);

    % fold-by-sample membership masks of the original partitions
    in_train = find_member(unbal_partitions, 'train_indices', nsamples);
    in_test = find_member(unbal_partitions, 'test_indices', nsamples);

    % marks the original folds matched by at least one balanced fold
    covered = false(n_unbal, 1);

    n_bal = numel(bal_partitions.train_indices);
    for bal_i = 1:n_bal
        tr = bal_partitions.train_indices{bal_i};
        te = bal_partitions.test_indices{bal_i};

        % an original fold matches if it contains all of tr and all of te
        train_ok = all(in_train(:, tr), 2);
        test_ok = all(in_test(:, te), 2);
        matches = train_ok & test_ok;

        assert(any(matches));
        covered(matches) = true;
    end

    assertEqual(covered, true(n_unbal, 1));

function msk = find_member(partitions, label, nsamples)
    % Returns a logical matrix with one row per fold of partitions.(label)
    % and nsamples columns; msk(k, i) is true if index i occurs in fold k.
    fold_indices = partitions.(label);
    fold_count = numel(fold_indices);
    msk = false(fold_count, nsamples);
    for fold_i = 1:fold_count
        members = fold_indices{fold_i};
        msk(fold_i, members) = true;
    end

function assert_partitions_ok(ds, partitions, balanced_test_indices)
    % Generic sanity checks for a partitions struct with respect to ds:
    % - it has exactly the fields 'train_indices' and 'test_indices',
    %   with the same number of folds in each
    % - every training fold is balanced; every test fold as well if
    %   balanced_test_indices is true
    % - per fold: no chunk is shared between training and test set, both
    %   sets cover the same targets, and no index occurs twice in a set
    expected_fields = sort({'train_indices'; 'test_indices'});
    actual_fields = sort(fieldnames(partitions));
    assertEqual(actual_fields, expected_fields);

    fold_count = numel(partitions.train_indices);
    assertEqual(numel(partitions.test_indices), fold_count);

    for f = 1:fold_count
        assert_fold_balanced(ds, partitions, f, 'train_indices');
        if balanced_test_indices
            assert_fold_balanced(ds, partitions, f, 'test_indices');
        end

        assert_fold_no_double_dipping(ds, partitions, f);
        assert_fold_targets_match(ds, partitions, f);
        assert_fold_indices_unique(partitions, f);
    end

function assert_fold_no_double_dipping(ds, partitions, fold)
    % Fails if any chunk value occurs both among the training samples and
    % among the test samples of the given fold.
    train_idx = partitions.train_indices{fold};
    test_idx = partitions.test_indices{fold};

    chunks = ds.sa.chunks;
    shared_chunks = intersect(chunks(train_idx), chunks(test_idx));

    assert(isempty(shared_chunks));

function assert_fold_balanced(ds, partitions, fold, label)
    % Checks that the fold of partitions.(label) contains every target
    % value present in ds, and the same number of samples for each.
    fold_indices = partitions.(label){fold};

    all_classes = unique(ds.sa.targets);
    fold_targets = ds.sa.targets(fold_indices);

    % every class of the dataset must be present in the fold
    assertEqual(unique(fold_targets), all_classes);

    % all class counts must equal the count of the first class
    counts = histc(fold_targets, all_classes);
    uniform_counts = repmat(counts(1), size(counts));
    assertEqual(uniform_counts, counts);

function assert_fold_targets_match(ds, partitions, fold)
    % Checks that the training and test indices of the fold are integers
    % between 1 and the number of rows in ds.samples, and that both sets
    % contain the same unique target values.
    sample_count = size(ds.samples, 1);

    train_idx = partitions.train_indices{fold};
    test_idx = partitions.test_indices{fold};

    assert_all_int_with_max(train_idx, sample_count);
    assert_all_int_with_max(test_idx, sample_count);

    targets = ds.sa.targets;
    train_classes = unique(targets(train_idx));
    test_classes = unique(targets(test_idx));
    assertEqual(train_classes, test_classes);

function assert_fold_indices_unique(partitions, fold)
    % Fails if the training set or the test set of the given fold
    % contains a repeated index: sorting a set must give the same result
    % as reducing it to its unique elements.
    train_idx = partitions.train_indices{fold};
    test_idx = partitions.test_indices{fold};

    train_has_no_repeats = isequal(sort(train_idx), unique(train_idx));
    assert(train_has_no_repeats);

    test_has_no_repeats = isequal(sort(test_idx), unique(test_idx));
    assert(test_has_no_repeats);

function assert_all_int_with_max(indices, max_value)
    % Fails unless all indices are integer-valued and lie in the range
    % 1..max_value.
    smallest = min(indices);
    assert(smallest >= 1);

    largest = max(indices);
    assert(largest <= max_value);

    is_whole = round(indices) == indices;
    assert(all(is_whole));

function test_balance_partitions_nmin
    % Runs cosmo_balance_partitions with the 'nmin' option, once with
    % 'balance_test' false and once with it true, and checks that:
    % - each output fold tests on samples from a single chunk
    % - the test indices are a subset of (balance_test true) or equal to
    %   (balance_test false) the original test indices for that chunk
    % - each training fold has equal counts for all classes
    % - over all folds testing on a chunk, every sample outside that
    %   chunk is used for training at least nmin times, and samples
    %   inside that chunk are never used for training
    nchunks = 5;
    nsamples = 200;
    nclasses = 4;
    [p, ds] = get_sample_data(nsamples, nchunks, nclasses);

    nmin = 8 + round(rand() * 4);
    args = struct();
    args.nmin = nmin;
    args.balance_test = [false, true];

    option_sets = cosmo_cartprod(args);

    for set_i = 1:numel(option_sets)
        opt = option_sets{set_i};
        b = cosmo_balance_partitions(p, ds, opt);

        % usage_count(i, c): how often sample i is used for training in
        % the folds that test on chunk c
        usage_count = zeros(nsamples, nchunks);

        bal_fold_count = numel(b.train_indices);
        for fold_i = 1:bal_fold_count
            train_idx = b.train_indices{fold_i};
            test_idx = b.test_indices{fold_i};

            test_chunk = unique(ds.sa.chunks(test_idx));
            assert(numel(test_chunk) == 1);

            orig_test_idx = p.test_indices{test_chunk};
            if opt.balance_test
                % no other indices
                assertEqual(setdiff(test_idx, orig_test_idx), zeros(0, 1));
            else
                assertEqual(sort(test_idx), orig_test_idx);
            end

            train_targets = ds.sa.targets(train_idx);

            class_counts = histc(train_targets, 1:nclasses);
            assertEqual(repmat(class_counts(1), nclasses, 1), class_counts);

            usage_count(train_idx, test_chunk) = ...
                            usage_count(train_idx, test_chunk) + 1;
        end

        for chunk_i = 1:nchunks
            in_chunk = ds.sa.chunks == chunk_i;
            assert(min(usage_count(~in_chunk, chunk_i)) >= nmin);
            assert(all(usage_count(in_chunk, chunk_i) == 0));
        end

        assert_partitions_ok(ds, b, opt.balance_test);
        assert_balanced_partitions_subset(p, b);
    end

function test_balance_partitions_exceptions
    % Checks that cosmo_balance_partitions throws an exception for a
    % series of invalid inputs. The modifications of ds and p below are
    % cumulative: each case builds on the state left by the previous one.
    ds = cosmo_synthetic_dataset();
    p = cosmo_nfold_partitioner(ds);

    % helper: calling cosmo_balance_partitions with these arguments must
    % raise an exception
    must_throw = @(varargin)assertExceptionThrown(@() ...
                                cosmo_balance_partitions(varargin{:}), '');

    % empty structs instead of partitions and dataset
    must_throw(struct, struct);

    % wrong order
    must_throw(ds, p);

    % 'nmin' and 'nrepeats' given together
    must_throw(p, ds, 'nmin', 1, 'nrepeats', 1);

    % create missing class
    ds.sa.targets(1) = 4;
    must_throw(p, ds);

    % missing target
    first_train = p.train_indices{1};
    p.train_indices{1} = first_train([1 3]);
    must_throw(p, ds);

    % double dipping
    p.train_indices{1} = p.train_indices{2};
    must_throw(p, ds);

function test_sorted_indices()
    % Builds partitions whose index vectors are randomly permuted, runs
    % cosmo_balance_partitions on them, and checks for each fold that the
    % output holds the same indices as the input, in sorted order.

    % switch warnings off; the cleanup object restores the earlier state
    % when this function exits
    original_warning_state = cosmo_warning();
    cleaner = onCleanup(@()cosmo_warning(original_warning_state));
    cosmo_warning('off');

    nchunks = ceil(cosmo_rand() * 10 + 10);
    ntargets = ceil(cosmo_rand() * 10 + 10);
    ds = cosmo_synthetic_dataset('nchunks', nchunks, 'ntargets', ntargets);

    chunks = ds.sa.chunks;
    nchunks = numel(unique(chunks));

    partitions = struct();
    partitions.train_indices = cell(nchunks, 1);
    partitions.test_indices = cell(nchunks, 1);

    % build partitions with unsorted indices
    for fold_i = 1:nchunks
        chunk_order = cosmo_randperm(nchunks);

        % train on two randomly chosen chunks, test on all the others
        in_first = chunks == chunk_order(1);
        in_second = chunks == chunk_order(2);
        is_train = in_first | in_second;

        train_idx = find(is_train);
        test_idx = find(~is_train);

        % shuffle the order of the indices within each set
        train_order = cosmo_randperm(numel(train_idx));
        test_order = cosmo_randperm(numel(test_idx));

        partitions.train_indices{fold_i} = train_idx(train_order);
        partitions.test_indices{fold_i} = test_idx(test_order);
    end

    % balance partitions
    bal_partitions = cosmo_balance_partitions(partitions, ds);

    % compare input and output fold by fold, for both index sets
    labels = {'train_indices', 'test_indices'};
    for fold_i = 1:nchunks
        for label_i = 1:numel(labels)
            label = labels{label_i};
            unbal_idx = partitions.(label){fold_i};
            bal_idx = bal_partitions.(label){fold_i};

            % indices must be the same
            assertEqual(sort(unbal_idx(:)), sort(bal_idx(:)));

            % balanced partitions must be sorted
            assertTrue(issorted(bal_idx));
        end
    end

function [p, ds] = get_sample_data(nsamples, nchunks, nclasses)
    % Builds a dataset struct ds with nsamples samples (a single column
    % holding 1..nsamples), random targets in 1..nclasses and random
    % chunks in 1..nchunks, together with the partitions p returned by
    % cosmo_nfold_partitioner for it.

    % draw the targets first, then the chunks
    target_draw = cosmo_rand(nsamples, 1);
    chunk_draw = cosmo_rand(nsamples, 1);

    ds = struct();
    ds.samples = transpose(1:nsamples);
    ds.sa.targets = ceil(target_draw * nclasses);
    ds.sa.chunks = ceil(chunk_draw * nchunks);

    p = cosmo_nfold_partitioner(ds);