test partitions

function test_suite = test_partitions
    % tests for partitioning functions
    %
    % Entry point of this test file. NOTE(review): presumably
    % initTestSuite (MOxUnit / MATLAB xUnit convention) collects the local
    % test_* functions below into the returned test_suite - confirm
    % against the test framework in use.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % test_functions is not used directly here; it is left in this
        % workspace, presumably so that initTestSuite can pick it up
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % errors (e.g. localfunctions being unavailable) are deliberately
        % ignored
    end
    initTestSuite;

function test_nfold_partitioner()
    % Check cosmo_nfold_partitioner on a 5-chunk, 4-target dataset:
    % fold k must test on samples (k-1)*4+(1:4) and train on all the other
    % samples. Passing the chunks vector instead of the dataset must give
    % the same partitions.
    n_chunks = 5;
    n_targets = 4;
    ds = cosmo_synthetic_dataset('nchunks', n_chunks, 'ntargets', n_targets);

    partitions = cosmo_nfold_partitioner(ds);
    assertEqual(partitions, cosmo_nfold_partitioner(ds.sa.chunks));

    assertEqual({'train_indices'; 'test_indices'}, fieldnames(partitions));

    n_samples = n_chunks * n_targets;
    for fold = 1:n_chunks
        expected_test = (fold - 1) * n_targets + (1:n_targets);
        expected_train = setdiff(1:n_samples, expected_test);

        % indices are expected as column vectors
        assertEqual(partitions.train_indices{fold}, expected_train');
        assertEqual(partitions.test_indices{fold}, expected_test');
    end

function test_nchoosek_partitioner()
    % Check cosmo_nchoosek_partitioner on a 5-chunk, 4-target dataset:
    % - one test chunk (as count 1 or fraction .2) equals n-fold
    % - three test chunks as count 3 equals fraction .6 but not .4
    % - the chunks vector can be passed instead of the dataset
    % - across the 10 folds, each sample is in the training set 4 times
    %   and in the test set 6 times
    ds = cosmo_synthetic_dataset('nchunks', 5, 'ntargets', 4);

    nfold_partitions = cosmo_nfold_partitioner(ds);
    assertEqual(nfold_partitions, cosmo_nchoosek_partitioner(ds, 1));
    assertEqual(nfold_partitions, cosmo_nchoosek_partitioner(ds, .2));

    partitions = cosmo_nchoosek_partitioner(ds, 3);
    assertEqual(partitions, cosmo_nchoosek_partitioner(ds, .6));
    assertFalse(isequal(partitions, cosmo_nchoosek_partitioner(ds, .4)));
    assertEqual(partitions, cosmo_nchoosek_partitioner(ds.sa.chunks, 3));

    % expected number of folds each sample occurs in, per field
    field_names = {'train_indices', 'test_indices'};
    expected_counts = [4 6];

    for f = 1:numel(field_names)
        folds = partitions.(field_names{f});
        occurrences = zeros(20, 1);

        assertEqual(size(folds), [1 10]);

        for i_fold = 1:numel(folds)
            idx = folds{i_fold};
            occurrences(idx) = occurrences(idx) + 1;
        end
        assertEqual(occurrences, repmat(expected_counts(f), 20, 1));
    end

function test_nchoosek_partitioner_half()
    % Check the 'half' option of cosmo_nchoosek_partitioner for 2, 6 and 10
    % chunks. The expected train/test chunk sets for fold k are rows k and
    % (n_combis-k+1) of nchoosek(1:nchunks, nchunks/2). The chunk values are
    % also shifted by a random offset so they do not start at 1.
    random_offset = floor(rand() * 10 + 20);
    for offset = [0 random_offset]
        for nchunks = [2 6 10]
            ds = cosmo_synthetic_dataset('nchunks', nchunks, 'ntargets', 3);
            ds.sa.chunks = ds.sa.chunks + offset;
            unshifted_chunks = ds.sa.chunks - offset;

            partitions = cosmo_nchoosek_partitioner(ds, 'half');

            % in lexicographic order, row k and row n_combis-k+1 hold
            % complementary halves of the chunks
            halves = nchoosek(1:nchunks, nchunks / 2);
            n_combis = size(halves, 1);
            n_folds = n_combis / 2;

            assertEqual(numel(partitions.train_indices), n_folds);
            assertEqual(numel(partitions.test_indices), n_folds);

            for fold = 1:n_folds
                expected_train = find(cosmo_match(unshifted_chunks, ...
                                                  halves(fold, :)));
                expected_test = find(cosmo_match(unshifted_chunks, ...
                                                 halves(n_combis - fold + 1, :)));

                assertEqual(partitions.train_indices{fold}, expected_train);
                assertEqual(partitions.test_indices{fold}, expected_test);
            end
        end
    end

function test_nchoosek_partitioner_grouping()
    % Check cosmo_nchoosek_partitioner with a 'modality' grouping argument.
    % Each expected fold tests on samples that are in the chosen test
    % chunks AND have the test modality. It trains on samples that are in
    % neither the test chunks NOR the test modality. Every expected fold
    % must match exactly one returned fold, train and test at the same
    % position.
    for nchunks = [2 5]
        ds = cosmo_synthetic_dataset('nchunks', nchunks, 'ntargets', 6);
        % odd targets get modality 2, even targets get modality 1
        ds.sa.modality = mod(ds.sa.targets, 2) + 1;
        ds.sa.targets = floor(ds.sa.targets / 2);

        for n_test = 1:(nchunks - 1)
            for moda_idx = 1:4

                % modas: modalities the test expects as test modality, one
                %        set of folds per element
                % moda_arg: the value passed to the partitioner
                %   moda_idx 1, 2: a single modality value
                %   moda_idx 3:    both modalities given as a cell {1 2}
                %   moda_idx 4:    empty, expected to behave like {1, 2}
                if moda_idx == 3
                    modas = {1 2};
                    moda_arg = {1 2};
                elseif moda_idx == 4
                    modas = {1, 2};
                    moda_arg = [];
                else
                    modas = {moda_idx};
                    moda_arg = moda_idx;
                end

                n_moda = numel(modas);

                p = cosmo_nchoosek_partitioner(ds, n_test, 'modality', ...
                                               moda_arg);
                combis = nchoosek(1:nchunks, n_test);
                n_combi = size(combis, 1);

                % one fold per (test-chunk combination, test modality) pair
                n_folds = numel(p.train_indices);
                assertEqual(numel(p.test_indices), n_folds);
                assertEqual(n_folds, n_combi * n_moda);

                % each fold must be present exactly once
                visited_count = zeros(1, n_folds);
                for m = 1:n_moda
                    for j = 1:n_combi
                        % train: outside the test chunks and outside the
                        % test modality
                        tr_msk = ~cosmo_match(ds.sa.chunks, combis(j, :)) & ...
                                    ~cosmo_match(ds.sa.modality, modas{m});
                        % test: inside the test chunks and the test modality
                        te_msk = cosmo_match(ds.sa.chunks, combis(j, :)) & ...
                                    cosmo_match(ds.sa.modality, modas{m});
                        % find_fold asserts a unique match
                        tr_idx = find_fold(p.train_indices, tr_msk);
                        te_idx = find_fold(p.test_indices, te_msk);
                        % train and test sets must belong to the same fold
                        assertEqual(tr_idx, te_idx);
                        visited_count(tr_idx) = visited_count(tr_idx) + 1;
                    end
                end

                assertEqual(visited_count, ones(1, n_folds));

                % also possible with indices
                % (positional form: chunks, count, modality, modality arg)
                p2 = cosmo_nchoosek_partitioner(ds.sa.chunks, n_test, ...
                                                ds.sa.modality, moda_arg);
                assertEqual(p, p2);

            end
        end
    end

function pos = find_fold(folds, msk)
    % Return the position in the cell array folds of the single fold whose
    % sample indices, ignoring order, equal find(msk).
    % Fails with a plain assert if no fold matches or more than one does.
    wanted = sort(reshape(find(msk), [], 1));
    is_match = cellfun(@(fold) isequal(sort(fold(:)), wanted), folds);
    pos = find(is_match);
    assert(numel(pos) == 1);

function assert_disjoint(vs, i, j)
    % Fail (via assertFalse) if the i-th and j-th entries of vs share a value.
    %
    % Inputs:
    %   vs      either a cell array of index vectors (e.g. a partition's
    %           train_indices or test_indices), in which case the contents
    %           vs{i} and vs{j} are compared, or a numeric array, in which
    %           case the elements vs(i) and vs(j) are compared
    %   i, j    positions in vs to compare
    %
    % Paren-indexing a cell array gives 1x1 cells, and intersect rejects
    % cells that hold numeric data, so cell input is brace-indexed.
    if iscell(vs)
        a = vs{i};
        b = vs{j};
    else
        a = vs(i);
        b = vs(j);
    end

    common = intersect(a, b);
    if ~isempty(common)
        assertFalse(true, sprintf('element in common: %d', common(1)));
    end

function test_nchoosek_partitioner_exceptions()
    % Check that cosmo_nchoosek_partitioner raises an error for invalid
    % inputs, for datasets with 2 and 3 chunks.
    %
    % aet(args...) asserts that cosmo_nchoosek_partitioner(args...) throws.
    % An empty expected identifier ('') is passed. NOTE(review): confirm how
    % the test framework interprets '' (any identifier vs. empty identifier).
    aet = @(varargin)assertExceptionThrown(@() ...
                                           cosmo_nchoosek_partitioner(varargin{:}), '');
    for nchunks = 2:3
        ds = cosmo_synthetic_dataset('nchunks', nchunks, 'ntargets', 4);

        % invalid numbers / fractions of test chunks
        aet(ds, -1);
        aet(ds, 0);
        aet(ds, 1.01);
        aet(ds, [1 1]);
        aet(ds, .99);
        % invalid first argument
        aet(struct, 1);
        % invalid second argument types / values
        aet(ds, 'foo');
        aet(ds, .5, 'foo');
        aet(ds, struct);
        % malformed trailing arguments
        aet(ds, 1, 1, 1);
        aet(ds.sa.chunks, 1, 'chunks', 1);
        aet(ds.sa.chunks, 1, 'chunks', 1, 'chunks');

        ds.sa.modality = 3; % size mismatch
        aet(ds, 1, 'modality', 1);

    end