function test_suite=test_balance_partitions
% tests for cosmo_balance_partitions
%
% Entry point of this test file. The output test_suite is not assigned
% explicitly below; presumably it is set by the initTestSuite script,
% which is defined outside this file - TODO confirm.
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions=localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
% NOTE(review): test_functions is not used in this file; presumably
% initTestSuite reads it from this workspace - confirm
initTestSuite;
function test_balance_partitions_repeats
% Test the 'nrepeats' and 'balance_test' options of
% cosmo_balance_partitions
%
% All combinations of balance_test (false, true, not set), nrepeats
% (1, 5, not set) and argument passing style (struct, or a cell with
% keys and values) are tried on randomly generated data.
    nchunks=5;
    nsamples=200;
    nclasses=4;
    [p,ds]=get_sample_data(nsamples,nchunks,nclasses);

    opt_cell={{ {'balance_test',false},...
    {'balance_test',true},...
    {},...
    },...
    { {'nrepeats',1},...
    {'nrepeats',5},...
    {}...
    },...
    { {'opt_as_struct',false},...
    {'opt_as_struct',true},...
    }};
    opt_prod=cosmo_cartprod(opt_cell);

    % values this test expects when an option is not set explicitly
    defaults=struct();
    defaults.nrepeats=1;
    defaults.balance_test=true;

    for opt_i=1:size(opt_prod,1)
        opt=opt_prod(opt_i,:);
        opt_struct=cosmo_structjoin(opt);

        % 'opt_as_struct' only steers this test; it is removed before
        % the options are passed to cosmo_balance_partitions
        opt_as_struct=opt_struct.opt_as_struct;
        opt_struct=rmfield(opt_struct,'opt_as_struct');

        if opt_as_struct
            args=opt_struct;
        else
            % 2xN cell: keys in the first row, values in the second
            args=[fieldnames(opt_struct) struct2cell(opt_struct)]';
        end

        b=cosmo_balance_partitions(p,ds,args);

        full_args=cosmo_structjoin(defaults,args);
        nrep=full_args.nrepeats;
        balance_test=full_args.balance_test;

        assertEqual(numel(b.train_indices),nrep*nchunks);
        assertEqual(numel(b.test_indices),nrep*nchunks);
        assertEqual(fieldnames(b),{'train_indices';'test_indices'});

        nfolds=numel(p.test_indices);
        for fold_j=1:nfolds
            unbal_idx=p.train_indices{fold_j};
            unbal_targets=ds.sa.targets(unbal_idx);

            % size of the smallest class in the unbalanced training
            % set; does not depend on the repeat, so compute it once
            min_count=min(histc(unbal_targets,1:nclasses));

            % distinct index names are used for the nested loops so
            % that the outer loop index is not overwritten
            for rep_i=1:nrep
                bal_fold=(fold_j-1)*nrep+rep_i;
                bal_idx=b.train_indices{bal_fold};
                bal_targets=ds.sa.targets(bal_idx);

                % every class must occur min_count times
                h=histc(bal_targets,1:nclasses)';
                assertTrue(all(min_count==h));
            end
        end

        assert_partitions_ok(ds,b,balance_test);
        assert_balanced_partitions_subset(p,b);
    end
function assert_balanced_partitions_subset(unbal_partitions,bal_partitions)
% Assert that balanced partitions are drawn from the unbalanced ones
%
% Inputs:
%   unbal_partitions   struct with cell fields .train_indices and
%                      .test_indices (the original partitions)
%   bal_partitions     struct with the same two fields
%
% Each fold in bal_partitions must have its train and test indices
% contained in the train and test indices, respectively, of at least
% one fold in unbal_partitions. In addition, every fold in
% unbal_partitions must be matched by at least one balanced fold.

    % largest index used anywhere in the original partitions; the
    % per-fold maxima are flattened to columns so that both row and
    % column cell arrays of folds yield a scalar here
    max_train=cellfun(@max,unbal_partitions.train_indices);
    max_test=cellfun(@max,unbal_partitions.test_indices);
    nsamples=max([max_train(:); max_test(:)]);

    unbal_nfolds=numel(unbal_partitions.train_indices);

    % see which indices were used in each fold
    msk_train=find_member(unbal_partitions,'train_indices',nsamples);
    msk_test=find_member(unbal_partitions,'test_indices',nsamples);

    unbal_was_used=false(unbal_nfolds,1);

    bal_nfolds=numel(bal_partitions.train_indices);
    for fold_i=1:bal_nfolds
        bal_train=bal_partitions.train_indices{fold_i};
        bal_test=bal_partitions.test_indices{fold_i};

        % original folds that contain all balanced train indices and
        % all balanced test indices
        candidate_msk=all(msk_train(:,bal_train),2) & ...
                      all(msk_test(:,bal_test),2);
        assert(any(candidate_msk));

        unbal_was_used(candidate_msk)=true;
    end

    assertEqual(unbal_was_used,true(unbal_nfolds,1));
function msk=find_member(partitions, label, nsamples)
% Build a logical fold-by-sample membership mask
%
% Inputs:
%   partitions   struct with a cell field named by label
%   label        name of the field holding the per-fold indices
%   nsamples     number of columns in the output mask
%
% Output:
%   msk          logical matrix in which msk(k,i) is true if index i
%                occurs in the k-th fold of partitions.(label)
    fold_indices=partitions.(label);
    fold_count=numel(fold_indices);

    msk=false(fold_count,nsamples);
    for fold_pos=1:fold_count
        members=fold_indices{fold_pos};
        msk(fold_pos,members)=true;
    end
function assert_partitions_ok(ds, partitions, balanced_test_indices)
% Run all per-fold sanity checks on a partitions struct
%
% Inputs:
%   ds                      dataset struct with .samples, .sa.targets
%                           and .sa.chunks
%   partitions              struct with exactly the fields
%                           .train_indices and .test_indices
%   balanced_test_indices   if true, the test indices of each fold are
%                           checked for balance as well
    expected_fields={'train_indices';'test_indices'};
    assertEqual(sort(fieldnames(partitions)),sort(expected_fields));

    fold_count=numel(partitions.train_indices);
    assertEqual(numel(partitions.test_indices),fold_count);

    % training indices are always checked for balance; test indices
    % only on request
    if balanced_test_indices
        labels_to_balance={'train_indices','test_indices'};
    else
        labels_to_balance={'train_indices'};
    end

    for fold=1:fold_count
        for label_i=1:numel(labels_to_balance)
            assert_fold_balanced(ds,partitions,fold,...
                                    labels_to_balance{label_i});
        end

        assert_fold_no_double_dipping(ds,partitions,fold);
        assert_fold_targets_match(ds,partitions,fold);
        assert_fold_indices_unique(partitions,fold);
    end
function assert_fold_no_double_dipping(ds, partitions, fold)
% Assert that train and test samples of one fold share no chunk value
%
% Inputs:
%   ds           dataset struct with .sa.chunks
%   partitions   struct with cell fields .train_indices, .test_indices
%   fold         index of the fold to check
    chunks=ds.sa.chunks;

    train_chunks=chunks(partitions.train_indices{fold});
    test_chunks=chunks(partitions.test_indices{fold});

    shared_chunks=intersect(train_chunks,test_chunks);
    assert(isempty(shared_chunks));
function assert_fold_balanced(ds, partitions, fold, label)
% Assert that all targets occur equally often in one fold
%
% Inputs:
%   ds           dataset struct with .sa.targets
%   partitions   struct with a cell field named by label
%   fold         index of the fold to check
%   label        field name, e.g. 'train_indices' or 'test_indices'
%
% The fold must contain every unique target of the dataset, and each
% of them the same number of times.
    all_targets=ds.sa.targets;
    classes=unique(all_targets);

    fold_indices=partitions.(label){fold};
    fold_targets=all_targets(fold_indices);

    % every class must be present ...
    assertEqual(unique(fold_targets),classes);

    % ... and every class count must equal the first one
    counts=histc(fold_targets,classes);
    assertEqual(repmat(counts(1),size(counts)),counts);
function assert_fold_targets_match(ds,partitions,fold)
% Assert that train and test set of one fold cover the same targets
%
% Inputs:
%   ds           dataset struct with .samples and .sa.targets
%   partitions   struct with cell fields .train_indices, .test_indices
%   fold         index of the fold to check
%
% Indices are first verified to be integers between 1 and the number
% of rows in ds.samples.
    train_idx=partitions.train_indices{fold};
    test_idx=partitions.test_indices{fold};

    max_index=size(ds.samples,1);
    assert_all_int_with_max(train_idx,max_index);
    assert_all_int_with_max(test_idx,max_index);

    targets=ds.sa.targets;
    train_classes=unique(targets(train_idx));
    test_classes=unique(targets(test_idx));
    assertEqual(train_classes,test_classes);
function assert_fold_indices_unique(partitions,fold)
% Assert that neither train nor test indices of a fold have repeats
%
% Inputs:
%   partitions   struct with cell fields .train_indices, .test_indices
%   fold         index of the fold to check
    labels={'train_indices','test_indices'};

    for label_i=1:numel(labels)
        fold_indices=partitions.(labels{label_i}){fold};

        % sorting keeps repeats whereas unique drops them, so the two
        % only agree if there are no repeats
        assert(isequal(sort(fold_indices),unique(fold_indices)));
    end
function assert_all_int_with_max(indices,max_value)
% Assert that indices are whole numbers in the range 1..max_value
%
% Inputs:
%   indices     numeric vector to check
%   max_value   largest value allowed
    lowest=min(indices);
    highest=max(indices);

    assert(lowest>=1);
    assert(highest<=max_value);

    is_whole=round(indices)==indices;
    assert(all(is_whole));
function test_balance_partitions_nmin
% Test the 'nmin' option of cosmo_balance_partitions
%
% A random nmin is combined with balance_test set to false and to
% true. For each test chunk, every sample outside that chunk must be
% used for training at least nmin times, and samples inside that
% chunk must never be used for training.
    nchunks=5;
    nsamples=200;
    nclasses=4;
    [p,ds]=get_sample_data(nsamples,nchunks,nclasses);

    nmin=8+round(rand()*4);

    args=struct('nmin',nmin,'balance_test',[false,true]);
    arg_prod=cosmo_cartprod(args);

    for combo_i=1:numel(arg_prod)
        opt=arg_prod{combo_i};
        b=cosmo_balance_partitions(p,ds,opt);

        % train_use(i,c) counts the balanced folds with test chunk c
        % that have sample i in their training set
        train_use=zeros(nsamples,nchunks);

        for fold_i=1:numel(b.train_indices)
            train_idx=b.train_indices{fold_i};
            test_idx=b.test_indices{fold_i};

            % each fold must test on a single chunk
            test_chunk=unique(ds.sa.chunks(test_idx));
            assert(numel(test_chunk)==1);

            % NOTE(review): indexing p.test_indices by the chunk value
            % presumes fold c of p tests on chunk c - confirm
            unbal_test_idx=p.test_indices{test_chunk};
            if opt.balance_test
                % no other indices
                assertEqual(setdiff(test_idx,unbal_test_idx),...
                                                    zeros(0,1));
            else
                assertEqual(sort(test_idx),unbal_test_idx);
            end

            % all classes must be equally frequent in the training set
            train_counts=histc(ds.sa.targets(train_idx),1:nclasses);
            assertEqual(ones(nclasses,1)*train_counts(1),train_counts);

            train_use(train_idx,test_chunk)=...
                            train_use(train_idx,test_chunk)+1;
        end

        for chunk=1:nchunks
            in_chunk=ds.sa.chunks==chunk;
            assert(min(train_use(~in_chunk,chunk))>=nmin);
            assert(all(train_use(in_chunk,chunk)==0));
        end

        assert_partitions_ok(ds,b,opt.balance_test);
        assert_balanced_partitions_subset(p,b);
    end
function test_balance_partitions_exceptions
% Test that cosmo_balance_partitions throws an exception on bad input
    ds=cosmo_synthetic_dataset();
    p=cosmo_nfold_partitioner(ds);

    % helper: assert that calling cosmo_balance_partitions with the
    % given arguments throws an exception
    balance=@cosmo_balance_partitions;
    aet=@(varargin)assertExceptionThrown(@()balance(varargin{:}),'');

    aet(struct(),struct());
    aet(ds,p); % wrong order
    aet(p,ds,'nmin',1,'nrepeats',1);

    % NOTE(review): the changes to ds and p below are cumulative; each
    % later case still carries the earlier modifications, so it may
    % throw for an earlier reason - confirm whether that is intended

    % create missing class
    ds.sa.targets(1)=4;
    aet(p,ds);

    % missing target
    first_train=p.train_indices{1};
    p.train_indices{1}=first_train([1 3]);
    aet(p,ds);

    % double dipping
    p.train_indices{1}=p.train_indices{2};
    aet(p,ds);
function test_sorted_indices()
% Test that cosmo_balance_partitions returns sorted indices
%
% Partitions with shuffled indices are built on a synthetic dataset;
% after balancing, each fold must hold the same indices as before,
% but in sorted order.

    % switch warnings off; the previous state is restored when
    % cleaner goes out of scope
    warning_state=cosmo_warning();
    cleaner=onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');

    nchunks=ceil(cosmo_rand()*10+10);
    ntargets=ceil(cosmo_rand()*10+10);
    ds=cosmo_synthetic_dataset('nchunks',nchunks,'ntargets',ntargets);
    nchunks=numel(unique(ds.sa.chunks));

    partitions=struct();
    partitions.train_indices=cell(nchunks,1);
    partitions.test_indices=cell(nchunks,1);

    % build partitions with unsorted indices: two randomly chosen
    % chunks for training, all other chunks for testing
    for fold=1:nchunks
        chunk_order=cosmo_randperm(nchunks);
        is_train=ismember(ds.sa.chunks,chunk_order(1:2));

        train_idx=find(is_train);
        test_idx=find(~is_train);

        train_shuffle=cosmo_randperm(numel(train_idx));
        test_shuffle=cosmo_randperm(numel(test_idx));

        partitions.train_indices{fold}=train_idx(train_shuffle);
        partitions.test_indices{fold}=test_idx(test_shuffle);
    end

    % balance partitions
    bal_partitions=cosmo_balance_partitions(partitions,ds);

    % all partitions must be sorted
    labels={'train_indices','test_indices'};
    for fold=1:nchunks
        for label_i=1:numel(labels)
            label=labels{label_i};
            unbal_idx=partitions.(label){fold};
            bal_idx=bal_partitions.(label){fold};

            % indices must be the same
            assertEqual(sort(unbal_idx(:)),sort(bal_idx(:)));

            % balanced partitions must be sorted
            assertTrue(issorted(bal_idx));
        end
    end
function [p,ds]=get_sample_data(nsamples,nchunks,nclasses)
% Generate a random dataset and its n-fold partitions
%
% Inputs:
%   nsamples   number of samples
%   nchunks    targets upper bound for .sa.chunks values
%   nclasses   upper bound for .sa.targets values
%
% Outputs:
%   p          output of cosmo_nfold_partitioner applied to ds
%   ds         struct with .samples set to (1:nsamples)' and random
%              .sa.targets and .sa.chunks, each nsamples x 1

    % targets are drawn before chunks, keeping the order of the
    % random number generator calls fixed
    sample_attr=struct();
    sample_attr.targets=ceil(nclasses*cosmo_rand(nsamples,1));
    sample_attr.chunks=ceil(nchunks*cosmo_rand(nsamples,1));

    ds=struct();
    ds.samples=transpose(1:nsamples);
    ds.sa=sample_attr;

    p=cosmo_nfold_partitioner(ds);