test oddeven partitioner

function test_suite = test_oddeven_partitioner
    % tests for cosmo_oddeven_partitioner
    %
    % Entry point of this test file; its output is test_suite.
    % NOTE(review): test_suite is not assigned in the visible code, so it
    % is presumably set by the initTestSuite script (unit-test framework,
    % e.g. MOxUnit), using the test_functions variable collected below --
    % confirm against the framework version in use.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % collect handles to all local functions defined in this file
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % (presumably localfunctions is unavailable there, hence the catch)
    end
    initTestSuite;

function test_oddeven_partitioner_basics
    % Check the 'full' and 'half' partitions produced by
    % cosmo_oddeven_partitioner for synthetic datasets with 2, 6 and 8
    % chunks. The samples are shuffled and every sample appears twice.
    orig_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    chunk_counts = [2 6 8];
    for i_count = 1:numel(chunk_counts)
        ds = cosmo_synthetic_dataset('ntargets', 3, ...
                                     'nchunks', chunk_counts(i_count));

        % shuffle the rows, then include the shuffled rows twice
        perm = randperm(size(ds.samples, 1));
        ds = cosmo_slice(ds, [perm perm]);

        % chunks at odd positions (1st, 3rd, ...) of the sorted unique
        % chunk values form one half; all remaining chunks the other half
        chunk_values = unique(ds.sa.chunks);
        in_odd_half = cosmo_match(ds.sa.chunks, chunk_values(1:2:end));
        odd_rows = find(in_odd_half);
        even_rows = find(~in_odd_half);

        % 'full': two folds, each half is used once for training
        expected_full = struct('train_indices', {{odd_rows, even_rows}}, ...
                               'test_indices', {{even_rows, odd_rows}});
        assert_partitions_equal(expected_full, ...
                                cosmo_oddeven_partitioner(ds, 'full'));

        % 'half': a single fold, training on the odd half only
        expected_half = struct('train_indices', {{odd_rows}}, ...
                               'test_indices', {{even_rows}});
        assert_partitions_equal(expected_half, ...
                                cosmo_oddeven_partitioner(ds, 'half'));

        % NOTE(review): passing ds.sa.chunks directly (instead of the
        % dataset struct) to cosmo_oddeven_partitioner is not tested here.
    end

function test_oddeven_partitioner_exceptions
    % Check that cosmo_oddeven_partitioner throws for each of:
    % an empty struct, a struct with only a .samples field, a dataset with
    % a single chunk, and a valid dataset with the second argument 'foo'.
    assert_throws = @(varargin) assertExceptionThrown( ...
                        @()cosmo_oddeven_partitioner(varargin{:}), '');

    assert_throws(struct());
    assert_throws(struct('samples', zeros(3, 4)));

    single_chunk_ds = cosmo_synthetic_dataset('nchunks', 1);
    assert_throws(single_chunk_ds);

    two_chunk_ds = cosmo_synthetic_dataset('nchunks', 2);
    assert_throws(two_chunk_ds, 'foo');

function assert_partitions_equal(p, q)
    % Assert that partition structs p and q both have exactly the fields
    % train_indices and test_indices, and that each fold in p contains the
    % same indices as the corresponding fold in q (order within a fold is
    % ignored).
    required_fields = sort({'train_indices'; 'test_indices'});
    p_fields = sort(fieldnames(p));
    assertEqual(p_fields, required_fields);
    assertEqual(p_fields, sort(fieldnames(q)));

    for key = {'train_indices', 'test_indices'}
        field_name = key{1};
        assert_cell_same_elements(p.(field_name), q.(field_name));
    end

function assert_cell_same_elements(p, q)
    % Assert that cell arrays p and q have the same size and that every
    % pair of corresponding elements is equal once both are sorted.
    assertEqual(size(p), size(q));

    for idx = 1:numel(p)
        p_sorted = sort(p{idx});
        q_sorted = sort(q{idx});
        assertEqual(p_sorted, q_sorted);
    end

function test_unbalanced_partitions()
    % For every combination of 2..5 targets and 2..5 chunks:
    %   1. build a dataset where each target occurs once per chunk, in
    %      shuffled order; cosmo_oddeven_partitioner must accept it;
    %   2. increase the chunk label of one randomly chosen sample by one;
    %      cosmo_oddeven_partitioner must then throw.
    orig_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    ds = struct();
    for ntargets = 2:5
        for nchunks = 2:5
            nsamples = ntargets * nchunks;

            % chunks cycle 1..nchunks; targets are constant within each
            % block of nchunks consecutive samples
            chunk_labels = repmat((1:nchunks)', ntargets, 1);
            target_labels = reshape(repmat(1:ntargets, nchunks, 1), [], 1);

            ds.samples = randn(nsamples, 1);
            ds.sa.chunks = chunk_labels;
            ds.sa.targets = target_labels;

            ds = cosmo_slice(ds, cosmo_randperm(nsamples));

            % balanced input: must not throw
            cosmo_oddeven_partitioner(ds);

            % bump one sample's chunk label -> no longer balanced
            victim = ceil(rand() * nsamples);
            ds.sa.chunks(victim) = ds.sa.chunks(victim) + 1;
            assertExceptionThrown(@() cosmo_oddeven_partitioner(ds), '*');
        end
    end