function test_suite = test_balance_partitions
% tests for cosmo_balance_partitions
%
% Entry point of this test file: returns test_suite, which is populated
% by initTestSuite. Presumably initTestSuite (MOxUnit/xUnit test
% framework) collects the local test_* functions below -- verify against
% the framework in use.
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
% test_functions is not used directly here; presumably initTestSuite
% reads it from this workspace to find the local test functions.
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
% deliberately empty: older Matlab/Octave without localfunctions
end
initTestSuite;
function test_balance_partitions_repeats
% Check cosmo_balance_partitions for all combinations of the
% 'balance_test' and 'nrepeats' options, passed either as a struct or as
% key-value pairs, against unbalanced n-fold partitions of random data.
    nchunks = 5;
    nsamples = 200;
    nclasses = 4;
    [p, ds] = get_sample_data(nsamples, nchunks, nclasses);

    % each row of opt_prod is one combination of options; an empty cell
    % means the option is not passed at all
    opt_cell = {{ {'balance_test', false}, ...
                  {'balance_test', true}, ...
                  {} ...
                }, ...
                { {'nrepeats', 1}, ...
                  {'nrepeats', 5}, ...
                  {} ...
                }, ...
                { {'opt_as_struct', false}, ...
                  {'opt_as_struct', true} ...
                }};
    opt_prod = cosmo_cartprod(opt_cell);

    % option values this test expects cosmo_balance_partitions to use
    % when an option is omitted
    defaults = struct();
    defaults.nrepeats = 1;
    defaults.balance_test = true;

    % distinct loop indices (opt_i, j, rep_i): the original reused 'k'
    % for both the outer and the inner loop
    for opt_i = 1:size(opt_prod, 1)
        opt = opt_prod(opt_i, :);
        opt_struct = cosmo_structjoin(opt);
        opt_as_struct = opt_struct.opt_as_struct;
        opt_struct = rmfield(opt_struct, 'opt_as_struct');
        if opt_as_struct
            args = opt_struct;
        else
            % 2 x N cell of key-value pairs (read column-wise)
            args = [fieldnames(opt_struct) struct2cell(opt_struct)]';
        end

        b = cosmo_balance_partitions(p, ds, args);
        full_args = cosmo_structjoin(defaults, args);
        nrep = full_args.nrepeats;
        balance_test = full_args.balance_test;

        assertEqual(numel(b.train_indices), nrep * nchunks);
        assertEqual(numel(b.test_indices), nrep * nchunks);
        assertEqual(fieldnames(b), {'train_indices'; 'test_indices'});

        % the nrep balanced folds derived from unbalanced fold j are stored
        % consecutively; in each of them every class must occur exactly as
        % often as the rarest class in the original training set
        nfolds = numel(p.test_indices);
        for j = 1:nfolds
            p_train_targets = ds.sa.targets(p.train_indices{j});
            min_count = min(histc(p_train_targets, 1:nclasses));
            for rep_i = 1:nrep
                fold_i = (j - 1) * nrep + rep_i;
                b_train_targets = ds.sa.targets(b.train_indices{fold_i});
                h = histc(b_train_targets, 1:nclasses)';
                assertTrue(all(min_count == h));
            end
        end

        assert_partitions_ok(ds, b, balance_test);
        assert_balanced_partitions_subset(p, b);
    end
function assert_balanced_partitions_subset(unbal_partitions, bal_partitions)
% each training and test fold in bal_partitions must correspond to
% a fold in the original partitions, and every original fold must be
% matched by at least one balanced fold
    % largest sample index used in any fold. Taking folds(:) makes this a
    % scalar for both row and column cell arrays of folds, and
    % max([0; idx(:)]) keeps empty folds from breaking cellfun.
    fold_max = @(folds) cellfun(@(idx) max([0; idx(:)]), folds(:));
    nsamples = max([fold_max(unbal_partitions.train_indices); ...
                    fold_max(unbal_partitions.test_indices)]);
    unbal_nfolds = numel(unbal_partitions.train_indices);

    % row k of each mask flags the samples used in unbalanced fold k
    msk_train = find_member(unbal_partitions, 'train_indices', nsamples);
    msk_test = find_member(unbal_partitions, 'test_indices', nsamples);

    unbal_was_used = false(unbal_nfolds, 1);
    bal_nfolds = numel(bal_partitions.train_indices);
    for fold_i = 1:bal_nfolds
        bal_train = bal_partitions.train_indices{fold_i};
        bal_test = bal_partitions.test_indices{fold_i};
        % unbalanced folds whose train and test sets contain this balanced
        % fold's train and test sets, respectively
        candidate_msk = all(msk_train(:, bal_train), 2) & ...
                        all(msk_test(:, bal_test), 2);
        assert(any(candidate_msk));
        unbal_was_used(candidate_msk) = true;
    end
    assertEqual(unbal_was_used, true(unbal_nfolds, 1));
function msk = find_member(partitions, label, nsamples)
% Return a logical matrix with one row per fold in partitions.(label):
% element (i, j) is true iff sample index j occurs in fold i. If a fold
% contains an index larger than nsamples, MATLAB's assignment growth adds
% the extra columns.
    folds = partitions.(label);
    msk = false(numel(folds), nsamples);
    for row = 1:size(msk, 1)
        msk(row, folds{row}) = true;
    end
function assert_partitions_ok(ds, partitions, balanced_test_indices)
% Verify that partitions has exactly the fields train_indices and
% test_indices with the same number of folds, and run every per-fold
% check. Test folds are checked for class balance only when
% balanced_test_indices is true.
    expected_fields = sort({'train_indices'; 'test_indices'});
    assertEqual(sort(fieldnames(partitions)), expected_fields);

    nfolds = numel(partitions.train_indices);
    assertEqual(numel(partitions.test_indices), nfolds);

    % which fold members must be class-balanced
    balanced_labels = {'train_indices'};
    if balanced_test_indices
        balanced_labels{end + 1} = 'test_indices';
    end

    for fold = 1:nfolds
        for lab_i = 1:numel(balanced_labels)
            assert_fold_balanced(ds, partitions, fold, balanced_labels{lab_i});
        end
        assert_fold_no_double_dipping(ds, partitions, fold);
        assert_fold_targets_match(ds, partitions, fold);
        assert_fold_indices_unique(partitions, fold);
    end
function assert_fold_no_double_dipping(ds, partitions, fold)
% Assert that the train and test sets of the given fold share no chunk.
    chunks = ds.sa.chunks;
    train_chunks = chunks(partitions.train_indices{fold});
    test_chunks = chunks(partitions.test_indices{fold});
    shared_chunks = intersect(train_chunks, test_chunks);
    assert(isempty(shared_chunks));
function assert_fold_balanced(ds, partitions, fold, label)
% Assert that the samples in fold 'fold' of partitions.(label) contain
% every target present in ds, each equally often.
    fold_indices = partitions.(label){fold};
    all_targets = ds.sa.targets;
    unq_targets = unique(all_targets);
    fold_targets = all_targets(fold_indices);
    assertEqual(unique(fold_targets), unq_targets);

    counts = histc(fold_targets, unq_targets);
    assertEqual(repmat(counts(1), size(counts)), counts);
function assert_fold_targets_match(ds, partitions, fold)
% Assert that the fold's indices are valid integer sample indices for ds
% and that its train and test sets contain the same set of targets.
    nsamples = size(ds.samples, 1);
    train_idx = partitions.train_indices{fold};
    test_idx = partitions.test_indices{fold};

    assert_all_int_with_max(train_idx, nsamples);
    assert_all_int_with_max(test_idx, nsamples);

    targets = ds.sa.targets;
    assertEqual(unique(targets(train_idx)), unique(targets(test_idx)));
function assert_fold_indices_unique(partitions, fold)
% Assert that neither the train nor the test indices of the given fold
% contain duplicates (sorting and de-duplicating must give the same).
    fns = {'train_indices', 'test_indices'};
    for fn_i = 1:numel(fns)
        idx = partitions.(fns{fn_i}){fold};
        assert(isequal(sort(idx), unique(idx)));
    end
function assert_all_int_with_max(indices, max_value)
% Assert that all indices are integer-valued and lie in 1..max_value.
    lowest = min(indices);
    highest = max(indices);
    assert(lowest >= 1);
    assert(highest <= max_value);
    is_integer = round(indices) == indices;
    assert(all(is_integer));
function test_balance_partitions_nmin
% Check the 'nmin' option of cosmo_balance_partitions: for every test
% chunk, each sample outside that chunk must be used for training at
% least nmin times, with and without balancing of the test set.
    nchunks = 5;
    nsamples = 200;
    nclasses = 4;
    [p, ds] = get_sample_data(nsamples, nchunks, nclasses);

    % draw with cosmo_rand, like the rest of this file, instead of the
    % global rand (which also disturbs the global RNG state)
    nmin = 8 + round(cosmo_rand() * 4);

    args = struct();
    args.nmin = nmin;
    args.balance_test = [false, true];
    arg_prod = cosmo_cartprod(args);

    for arg_i = 1:numel(arg_prod)
        arg = arg_prod{arg_i};
        b = cosmo_balance_partitions(p, ds, arg);

        % counter(i, c): number of balanced folds testing chunk c whose
        % training set contains sample i
        counter = zeros(nsamples, nchunks);
        for j = 1:numel(b.train_indices)
            bi = b.train_indices{j};
            bj = b.test_indices{j};
            ch = unique(ds.sa.chunks(bj));
            assert(numel(ch) == 1);

            % assumes fold ch of p tests chunk ch (chunks are 1..nchunks)
            if arg.balance_test
                % no other indices
                assertEqual(setdiff(bj, p.test_indices{ch}), zeros(0, 1));
            else
                assertEqual(sort(bj), p.test_indices{ch});
            end

            bt = ds.sa.targets(bi);
            h = histc(bt, 1:nclasses);
            assertEqual(ones(nclasses, 1) * h(1), h);
            counter(bi, ch) = counter(bi, ch) + 1;
        end

        for k = 1:nchunks
            msk = ds.sa.chunks ~= k;
            assert(min(counter(msk, k)) >= nmin);
            assert(all(counter(~msk, k) == 0));
        end

        assert_partitions_ok(ds, b, arg.balance_test);
        assert_balanced_partitions_subset(p, b);
    end
function test_balance_partitions_exceptions
% Check that cosmo_balance_partitions rejects invalid input. Each invalid
% case starts from the pristine dataset and partitions, so that every
% case has to be detected on its own rather than through a defect left
% behind by an earlier case.
    ds = cosmo_synthetic_dataset();
    p = cosmo_nfold_partitioner(ds);
    aet = @(varargin)assertExceptionThrown(@() ...
                        cosmo_balance_partitions(varargin{:}), '');

    aet(struct, struct);
    aet(ds, p); % wrong order
    aet(p, ds, 'nmin', 1, 'nrepeats', 1);

    % create missing class
    bad_ds = ds;
    bad_ds.sa.targets(1) = 4;
    aet(p, bad_ds);

    % missing target
    bad_p = p;
    bad_p.train_indices{1} = bad_p.train_indices{1}([1 3]);
    aet(bad_p, ds);

    % double dipping
    bad_p = p;
    bad_p.train_indices{1} = p.train_indices{2};
    aet(bad_p, ds);
function test_sorted_indices()
% Balance partitions whose train/test indices are randomly shuffled and
% check that each balanced fold keeps exactly the same indices but
% returns them in sorted order.
    warning_state = cosmo_warning();
    % kept in a variable so the warning state is restored when this
    % function exits
    cleaner = onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');

    nchunks = ceil(cosmo_rand() * 10 + 10);
    ntargets = ceil(cosmo_rand() * 10 + 10);
    ds = cosmo_synthetic_dataset('nchunks', nchunks, 'ntargets', ntargets);
    chunks = ds.sa.chunks;
    nchunks = numel(unique(chunks));

    % build partitions with unsorted indices: each fold trains on two
    % randomly chosen chunks and tests on all remaining samples
    partitions = struct();
    partitions.train_indices = cell(nchunks, 1);
    partitions.test_indices = cell(nchunks, 1);
    for fold = 1:nchunks
        rp = cosmo_randperm(nchunks);
        train_msk = chunks == rp(1) | chunks == rp(2);
        train_idx = find(train_msk);
        test_idx = find(~train_msk);
        train_order = cosmo_randperm(numel(train_idx));
        test_order = cosmo_randperm(numel(test_idx));
        partitions.train_indices{fold} = train_idx(train_order);
        partitions.test_indices{fold} = test_idx(test_order);
    end

    bal_partitions = cosmo_balance_partitions(partitions, ds);

    fns = {'train_indices', 'test_indices'};
    for fold = 1:nchunks
        for fn_i = 1:numel(fns)
            fn = fns{fn_i};
            orig_idx = partitions.(fn){fold};
            bal_idx = bal_partitions.(fn){fold};
            % same set of indices, regardless of order ...
            assertEqual(sort(orig_idx(:)), sort(bal_idx(:)));
            % ... but the balanced ones must come out sorted
            assertTrue(issorted(bal_idx));
        end
    end
function [p, ds] = get_sample_data(nsamples, nchunks, nclasses)
% Build a minimal dataset of nsamples samples with random targets in
% 1..nclasses and random chunks in 1..nchunks (assuming cosmo_rand draws
% from (0, 1)), and return it with its n-fold partitions.
    samples = (1:nsamples)';
    targets = ceil(cosmo_rand(nsamples, 1) * nclasses);
    chunks = ceil(cosmo_rand(nsamples, 1) * nchunks);
    sa = struct('targets', targets, 'chunks', chunks);
    ds = struct('samples', samples, 'sa', sa);
    p = cosmo_nfold_partitioner(ds);