test check partitions

function test_suite = test_check_partitions()
    % tests for cosmo_check_partitions
    %
    % Entry point of this test file: returns test_suite, which is built
    % by the initTestSuite script (provided by the test framework, not by
    % this file) from the test_* subfunctions defined below.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % the variable name 'test_functions' must not change: it is read
        % from this workspace by initTestSuite
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % (localfunctions is not available there; deliberately ignored)
    end
    initTestSuite;

function test_check_partitions_exceptions()
    % Exercise the error paths of cosmo_check_partitions, and verify that
    % a few well-formed inputs are accepted.
    %
    % Local helpers:
    %   must_fail(p, ds, ...)  asserts that cosmo_check_partitions throws
    %   must_pass(p, ds, ...)  calls cosmo_check_partitions; any error
    %                          propagates and fails the test
    %   as_partitions(tr, te)  builds a partitions struct with fields
    %                          .train_indices=tr and .test_indices=te
    must_fail = @(varargin)assertExceptionThrown(@() ...
                                    cosmo_check_partitions(varargin{:}), '');
    must_pass = @(varargin)cosmo_check_partitions(varargin{:});
    % braces make cell-valued arguments become the field values of a
    % single (1x1) struct instead of producing a struct array
    as_partitions = @(tr, te)struct('train_indices', {tr}, ...
                                    'test_indices', {te});

    ds = cosmo_synthetic_dataset();
    uneven_test = {[3 4 5 6], [3 4 6]};

    % a struct without any fields is rejected
    must_fail(struct(), ds);

    % empty train and test indices are rejected as well
    must_fail(as_partitions([], []), ds);

    % the number of train and test folds must match
    must_fail(as_partitions({[1 2], [1 2]}, {1}), ds);

    % unbalanced test indices are accepted
    must_pass(as_partitions({[1 2], [1 2]}, uneven_test), ds);

    % unbalance in unique targets: error unless explicitly allowed
    p = as_partitions({[1 2], 1}, uneven_test);
    must_fail(p, ds);
    must_fail(p, ds, 'unbalanced_partitions_ok', false);
    must_pass(p, ds, 'unbalanced_partitions_ok', true);

    % unbalanced training fold: error unless explicitly allowed
    p = as_partitions({[1 2], [1 2 3]}, {[5 6], [5 6]});
    must_fail(p, ds);
    must_fail(p, ds, 'unbalanced_partitions_ok', false);
    must_pass(p, ds, 'unbalanced_partitions_ok', true);

    % indices must be integers that do not exceed the valid range
    must_fail(as_partitions({[1 2], [4 7]}, uneven_test), ds);
    must_fail(as_partitions({[1 2], [4.5 5.5]}, uneven_test), ds);

    % an empty fold is rejected
    must_fail(as_partitions({[1 2], []}, uneven_test), ds);

    % using a sample for both training and testing within a fold
    % (double dipping) is rejected, whatever the balancing option
    p = as_partitions({[1 2], [3 4]}, uneven_test);
    must_fail(p, ds);
    must_fail(p, ds, 'unbalanced_partitions_ok', false);
    must_fail(p, ds, 'unbalanced_partitions_ok', true);

    % the second argument must be a proper dataset struct
    ds_without_sa = ds;
    ds_without_sa.sa = struct();
    must_fail(p, ds_without_sa);
    must_fail(p, struct());

    % missing targets are acceptable...
    ds3 = cosmo_synthetic_dataset('ntargets', 3);
    shared_test = {[8 9], [1 3]};
    must_pass(as_partitions({[2 3 5 6], [5 6 8 9]}, shared_test), ds3);

    % ...as long as this is consistent across the folds
    must_fail(as_partitions({[1 2 4 5], [5 6 8 9]}, shared_test), ds3);

function test_warning_shown_unsorted_indices()
    % Check that cosmo_check_partitions records exactly one new warning
    % when either the train or the test indices are unsorted, and no new
    % warning when both are sorted.

    % restore the original warning state on exit; the onCleanup object
    % must be kept in a variable to stay alive until the function returns
    saved_state = cosmo_warning();
    restore_guard = onCleanup(@()cosmo_warning(saved_state));

    % start from an empty warning history, and do not print warnings
    cosmo_warning('reset');
    cosmo_warning('off');

    ds = cosmo_synthetic_dataset();
    last_chunk = max(ds.sa.chunks);

    % single fold: train on all chunks except the last, test on the last
    base = struct();
    base.train_indices = {find(ds.sa.chunks < last_chunk)};
    base.test_indices = {find(ds.sa.chunks == last_chunk)};

    % iteration 0 keeps everything sorted; iterations 1 and 2 reverse the
    % train and the test indices, respectively
    field_names = {'train_indices', 'test_indices'};
    n_before = 0; % no warnings recorded right after the reset
    for iter = 0:2
        partitions = base;
        if iter > 0
            label = field_names{iter};
            flipped = partitions.(label){1}(end:-1:1);
            partitions.(label){1} = flipped;
            assertFalse(issorted(flipped));
        end

        cosmo_check_partitions(partitions, ds);

        warning_state = cosmo_warning();
        n_shown = numel(warning_state.shown_warnings);

        % exactly one extra warning is expected for unsorted indices
        assertEqual(n_shown, n_before + double(iter > 0));

        n_before = n_shown;
    end