function test_suite = test_check_partitions()
% tests for cosmo_check_partitions
%
% Entry point of this test file. It collects handles to the local test
% functions (where the running Matlab/Octave version supports that) and
% then runs initTestSuite.
%
% Returns:
%   test_suite      not assigned explicitly in this function;
%                   NOTE(review): presumably set by initTestSuite, which
%                   is defined by the unit test framework - confirm there.
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
% NOTE(review): the variable name 'test_functions' is presumably what
% initTestSuite looks for in this workspace; do not rename without
% checking the test framework.
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
initTestSuite;
function test_check_partitions_exceptions()
% Checks which partition / dataset inputs cosmo_check_partitions accepts
% and which ones make it throw an exception.
%
% Two local helpers are used throughout:
%   aet(...)     passes its arguments to cosmo_check_partitions inside
%                assertExceptionThrown (with '' as expected identifier),
%                so the test fails if no exception is raised
%   is_ok(...)   calls cosmo_check_partitions directly, so the test fails
%                if an exception is raised
%
% The partitions struct p is modified incrementally: each scenario only
% overwrites the field(s) it needs, and inherits the other field from the
% preceding scenario.
    aet = @(varargin)assertExceptionThrown(@() ...
                    cosmo_check_partitions(varargin{:}), '');
    is_ok = @(varargin)cosmo_check_partitions(varargin{:});

    ds = cosmo_synthetic_dataset();

    % empty input
    p = struct();
    aet(p, ds);
    p.train_indices = [];
    p.test_indices = [];
    aet(p, ds);

    % fold count mismatch
    p.train_indices = {[1 2], [1 2]};
    p.test_indices = {1};
    aet(p, ds);

    % unbalance in test indices is ok
    p.test_indices = {[3 4 5 6], [3 4 6]};
    is_ok(p, ds);

    % error for unbalance in unique targets, unless overridden
    p.train_indices = {[1 2], 1};
    aet(p, ds);
    aet(p, ds, 'unbalanced_partitions_ok', false);
    is_ok(p, ds, 'unbalanced_partitions_ok', true);

    % error for unbalance over chunks, unless overridden
    p.train_indices = {[1 2], [1 2 3]};
    p.test_indices = {[5 6], [5 6]};
    aet(p, ds);
    aet(p, ds, 'unbalanced_partitions_ok', false);
    is_ok(p, ds, 'unbalanced_partitions_ok', true);

    % indices must be integers not exceeding range
    p.train_indices = {[1 2], [4 7]};
    p.test_indices = {[3 4 5 6], [3 4 6]};
    aet(p, ds);
    p.train_indices = {[1 2], [4.5 5.5]};
    aet(p, ds);

    % empty indices are not allowed
    p.train_indices = {[1 2], []};
    aet(p, ds);

    % no double dipping allowed (indices 3 and 4 are in both the train
    % and the test set of the second fold), also not when unbalanced
    % partitions are allowed
    p.train_indices = {[1 2], [3 4]};
    aet(p, ds);
    aet(p, ds, 'unbalanced_partitions_ok', false);
    aet(p, ds, 'unbalanced_partitions_ok', true);

    % second input must be dataset struct
    % The double-dipping partitions above are already asserted to throw
    % with a valid dataset, so reusing them here would pass regardless of
    % whether the dataset is checked. Go back to the partitions that were
    % accepted earlier in this test, confirm they are still accepted for
    % the intact dataset, and only then break the dataset.
    p.train_indices = {[1 2], [1 2]};
    p.test_indices = {[3 4 5 6], [3 4 6]};
    is_ok(p, ds);
    ds.sa = struct();
    aet(p, ds);
    ds = struct();
    aet(p, ds);

    % it's fine to have missing targets...
    ds = cosmo_synthetic_dataset('ntargets', 3);
    p = struct();
    p.train_indices = {[2 3 5 6], [5 6 8 9]};
    p.test_indices = {[8 9], [1 3]};
    is_ok(p, ds);

    % ...but if so it must be consistent across the folds
    p.train_indices = {[1 2 4 5], [5 6 8 9]};
    p.test_indices = {[8 9], [1 3]};
    aet(p, ds);
function test_warning_shown_unsorted_indices()
% Checks the bookkeeping of shown warnings when cosmo_check_partitions is
% given unsorted indices.
%
% Starting from partitions built with find() (train: all chunks below the
% maximum chunk value; test: the maximum chunk value), three variants are
% checked in turn: the partitions as they are, with the train indices
% reversed, and with the test indices reversed. After each call the
% number of entries in the 'shown_warnings' field of the cosmo_warning
% state must be unchanged for the first variant, and one higher than
% before for each of the reversed variants.

    % keep the current warning state; it is restored through the
    % onCleanup object when this function's workspace is cleared
    initial_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(initial_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    ds = cosmo_synthetic_dataset();
    last_chunk = max(ds.sa.chunks);

    reference_partitions = struct();
    reference_partitions.train_indices = {find(ds.sa.chunks < last_chunk)};
    reference_partitions.test_indices = {find(ds.sa.chunks == last_chunk)};

    field_names = {'train_indices', 'test_indices'};

    n_warnings_before = 0; % because of reset

    % index 0 means: leave the partitions untouched; index j > 0 means:
    % reverse the indices stored in field_names{j}
    for field_idx = 0:numel(field_names)
        candidate = reference_partitions;
        expect_new_warning = field_idx > 0;

        if expect_new_warning
            name = field_names{field_idx};
            flipped = candidate.(name){1}(end:-1:1);
            candidate.(name){1} = flipped;

            % sanity check: reversing must really give unsorted indices
            assertFalse(issorted(flipped));
        end

        cosmo_check_partitions(candidate, ds);

        current_state = cosmo_warning();
        n_warnings_now = numel(current_state.shown_warnings);

        assertEqual(n_warnings_now, ...
                    n_warnings_before + double(expect_new_warning));

        n_warnings_before = n_warnings_now;
    end