test check partitions

function test_suite=test_check_partitions()
% tests for cosmo_check_partitions
%
% Returns a test suite built from the local test_* functions in this file.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % collect handles to the local functions of this file; the
        % variable name 'test_functions' is presumably what initTestSuite
        % looks for in this workspace (verify against initTestSuite)
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % localfunctions is not available; deliberately ignore the error
    end
    % initTestSuite is invoked as a script (no output assignment here), so
    % it presumably sets 'test_suite' in this function's workspace
    initTestSuite;

function test_check_partitions_exceptions()
% verify that cosmo_check_partitions raises an exception for invalid
% partitions or datasets, and accepts valid (possibly unbalanced, when
% explicitly allowed) partitions.
%
% Each invalid case is constructed so that it violates only the rule named
% in its comment; otherwise the expected exception could be raised for an
% unrelated reason and the case would not test what it claims to.
    aet=@(varargin)assertExceptionThrown(@()...
                        cosmo_check_partitions(varargin{:}),'');
    is_ok=@(varargin)cosmo_check_partitions(varargin{:});
    ds=cosmo_synthetic_dataset();

    % empty input
    p=struct();
    aet(p,ds);
    p.train_indices=[];
    p.test_indices=[];
    aet(p,ds);

    % fold count mismatch between train_indices and test_indices
    p.train_indices={[1 2],[1 2]};
    p.test_indices={1};
    aet(p,ds);

    % unbalance in test indices is ok
    p.test_indices={[3 4 5 6],[3 4 6]};
    is_ok(p,ds);

    % error when a training fold lacks some of the unique targets,
    % unless overridden
    p.train_indices={[1 2],1};
    aet(p,ds);
    aet(p,ds,'unbalanced_partitions_ok',false);
    is_ok(p,ds,'unbalanced_partitions_ok',true);

    % error when targets occur an unequal number of times in a training
    % fold, unless overridden
    p.train_indices={[1 2],[1 2 3]};
    p.test_indices={[5 6],[5 6]};
    aet(p,ds);
    aet(p,ds,'unbalanced_partitions_ok',false);
    is_ok(p,ds,'unbalanced_partitions_ok',true);

    % indices must be integers not exceeding range.
    % The training indices are chosen disjoint from the test indices of the
    % same fold, so that the error cannot be caused by double dipping
    p.train_indices={[1 2],[5 7]};
    p.test_indices={[3 4 5 6],[3 4 6]};
    aet(p,ds);
    p.train_indices={[1 2],[4.5 5.5]};
    aet(p,ds);

    % empty indices are not allowed
    p.train_indices={[1 2],[]};
    aet(p,ds);


    % no double dipping allowed (samples 3 and 4 are used for both
    % training and testing in the second fold)
    p.train_indices={[1 2],[3 4]};
    aet(p,ds);
    aet(p,ds,'unbalanced_partitions_ok',false);
    aet(p,ds,'unbalanced_partitions_ok',true);

    % second input must be dataset struct
    ds.sa=struct();
    aet(p,ds);
    ds=struct();
    aet(p,ds);

    % it's fine to have missing targets...
    ds=cosmo_synthetic_dataset('ntargets',3);
    p=struct();
    p.train_indices={[2 3 5 6], [5 6 8 9]};
    p.test_indices={[8 9],[1 3]};
    is_ok(p,ds);

    % ...but if so it must be consistent across the folds
    p.train_indices={[1 2 4 5],[5 6 8 9]};
    p.test_indices={[8 9],[1 3]};
    aet(p,ds);

function test_warning_shown_unsorted_indices()
% unsorted train or test indices should make cosmo_check_partitions record
% exactly one new warning; sorted indices should record none
    saved_warning_state=cosmo_warning();
    % restore the caller's warning state however this test exits
    restore_on_exit=onCleanup(@()cosmo_warning(saved_warning_state));

    % start without any recorded warnings, and keep output quiet
    cosmo_warning('reset');
    cosmo_warning('off');


    ds=cosmo_synthetic_dataset();

    % single fold: the last chunk is used for testing, the rest for training
    last_chunk=max(ds.sa.chunks);
    base_partitions=struct();
    base_partitions.test_indices={find(ds.sa.chunks==last_chunk)};
    base_partitions.train_indices={find(ds.sa.chunks<last_chunk)};

    field_names={'train_indices','test_indices'};
    n_prev_warnings=0; % nothing recorded yet after the reset above

    % iteration 0 uses sorted indices; iteration i>0 reverses the indices
    % stored in field_names{i}
    for iter=0:numel(field_names)
        partitions=base_partitions;
        expect_new_warning=iter>0;

        if expect_new_warning
            field_name=field_names{iter};
            sorted_idx=partitions.(field_name){1};
            unsorted_idx=sorted_idx(end:-1:1);
            assertFalse(issorted(unsorted_idx))
            partitions.(field_name){1}=unsorted_idx;
        end

        cosmo_check_partitions(partitions,ds);

        warning_state=cosmo_warning();
        n_warnings=numel(warning_state.shown_warnings);

        assertEqual(n_warnings,...
                    n_prev_warnings+double(expect_new_warning));

        n_prev_warnings=n_warnings;
    end