function bal_partitions=cosmo_balance_partitions(partitions,ds, varargin)
% balances a partition so that each target occurs equally often in each
% training and test chunk
%
% bpartitions=cosmo_balance_partitions(partitions, ds, ...)
%
% Inputs:
%   partitions        struct with fields:
%     .train_indices  } Each is a 1xN cell (for N chunks) containing the
%     .test_indices   } sample indices for each partition
%   ds                dataset struct with field .sa.targets.
%   'nrepeats',nr     Number of repeats (default: 1). The output will
%                     have nr times as many partitions as the input set.
%                     This option is not compatible with 'nmin'.
%   'nmin',nm         Ensure that each sample occurs at least
%                     nmin times in each training set (some samples may
%                     be repeated more often than that). This option is
%                     not compatible with 'nrepeats'.
%   'balance_test'    If set to false, indices in the test set are not
%                     necessarily balanced. The default is true.
%   'seed',sd         Use seed sd for pseudo-random number generation.
%                     Different values lead almost always to different
%                     pseudo-random orders. To disable using a seed - which
%                     causes this function to give different results upon
%                     subsequent calls with identical inputs - use sd=0.
%
% Output:
%   bpartitions       similar struct as input partitions, except that
%                     - each field is a 1x(N*nsets) cell
%                     - each unique target is represented about equally often
%                     - each target in each training chunk occurs equally
%                       often
%
% Examples:
%     % generate a simple dataset with unbalanced partitions
%     ds=struct();
%     ds.samples=zeros(9,2);
%     ds.sa.targets=[1 1 2 2 2 3 3 3 3]';
%     ds.sa.chunks=[1 2 2 1 1 1 2 2 2]';
%     p=cosmo_nfold_partitioner(ds);
%     %
%     % show original (unbalanced) partitioning
%     cosmo_disp(p);
%     %|| .train_indices
%     %||   { [ 2    [ 1
%     %||       3      4
%     %||       7      5
%     %||       8      6 ]
%     %||       9 ]        }
%     %|| .test_indices
%     %||   { [ 1    [ 2
%     %||       4      3
%     %||       5      7
%     %||       6 ]    8
%     %||              9 ]  }
%     %
%     % make standard balancing (nsets=1); some targets are not used
%     q=cosmo_balance_partitions(p,ds);
%     cosmo_disp(q);
%     %|| .train_indices
%     %||   { [ 2    [ 1
%     %||       3      5
%     %||       7 ]    6 ]  }
%     %|| .test_indices
%     %||   { [ 1    [ 2
%     %||       5      3
%     %||       6 ]    7 ]  }
%     %
%     % make balancing where each sample in each training fold is used at
%     % least once
%     q=cosmo_balance_partitions(p,ds,'nmin',1);
%     cosmo_disp(q);
%     %|| .train_indices
%     %||   { [ 2    [ 2    [ 2    [ 1    [ 1
%     %||       3      3      3      5      4
%     %||       7 ]    9 ]    8 ]    6 ]    6 ]  }
%     %|| .test_indices
%     %||   { [ 1    [ 1    [ 1    [ 2    [ 2
%     %||       5      4      5      3      3
%     %||       6 ]    6 ]    6 ]    7 ]    9 ]  }
%     %
%     % triple the number of partitions and sample from training indices
%     q=cosmo_balance_partitions(p,ds,'nrepeats',3);
%     cosmo_disp(q);
%     %|| .train_indices
%     %||   { [ 2    [ 2    [ 2    [ 1    [ 1    [ 1
%     %||       3      3      3      5      4      5
%     %||       7 ]    9 ]    8 ]    6 ]    6 ]    6 ]  }
%     %|| .test_indices
%     %||   { [ 1    [ 1    [ 1    [ 2    [ 2    [ 2
%     %||       5      4      5      3      3      3
%     %||       6 ]    6 ]    6 ]    7 ]    9 ]    8 ]  }
%
% Notes:
%  - this function is intended for datasets where the number of
%    samples across targets is not equally distributed. A typical
%    application is MEEG datasets.
%  - By default both the train and test indices are balanced, so that
%    chance accuracy is equal to the inverse of the number of unique
%    targets (1/C with C the number of classes).
%    Balancing is considered a *Good Thing*:
%    * Suppose the entire dataset has 75% samples of
%      class A and 25% samples of class B, but the data does not contain
%      any information that allows for discrimination between the classes.
%      A classifier trained on a subset may always predict the class that
%      occurred most often in the training set, which is class A. If the
%      test set also contains 75% of class A, then classification accuracy
%      would be 75%, which is higher than 1/2 (with 2 the number of
%      classes).
%    * Balancing the training set only would accommodate this issue, but it
%      may still be the case that a classifier consistently predicts one
%      class more often than other classes. While this may be unbiased with
%      respect to predictions of one particular class over many dataset
%      instances, it could lead to biases (either above or below chance)
%      in particular instances.
%
% See also: cosmo_nchoosek_partitioner, cosmo_nfold_partitioner
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

defaults=struct();
defaults.seed=1;
defaults.balance_test=true;
params=cosmo_structjoin(defaults,varargin);

% the input may be unbalanced; the balanced output is verified at the end
cosmo_check_partitions(partitions,ds,'unbalanced_partitions_ok',true);

classes=unique(ds.sa.targets);

nfolds_in=numel(partitions.train_indices);
train_indices_out=cell(1,nfolds_in);
test_indices_out=cell(1,nfolds_in);

for j=1:nfolds_in
    tr_idx=partitions.train_indices{j};
    te_idx=partitions.test_indices{j};

    % every class must be present in the training set of each fold
    tr_targets=ds.sa.targets(tr_idx);
    [tr_fold_classes,tr_fold_class_pos]=get_classes(tr_targets);
    if ~isequal(tr_fold_classes,classes)
        missing=setdiff(classes,tr_fold_classes);
        error('missing training class %d in fold %d', missing(1), j);
    end

    % see how many output folds for the current input fold
    nfolds_out=get_nfolds_out(tr_fold_class_pos,params);

    train_indices_out{j}=sample_indices(tr_idx,tr_fold_class_pos,...
                                            nfolds_out,params);
    if params.balance_test
        % every class must be present in the test set of each fold
        te_targets=ds.sa.targets(te_idx);
        [te_fold_classes,te_fold_class_pos]=get_classes(te_targets);
        if ~isequal(te_fold_classes,classes)
            missing=setdiff(classes,te_fold_classes);
            error('missing test class %d in fold %d', missing(1), j);
        end

        test_indices_out{j}=sample_indices(te_idx,te_fold_class_pos,...
                                            nfolds_out,params);
    else
        % keep the test indices as they are, duplicated for each output fold
        test_indices_out{j}=repmat({te_idx},1,nfolds_out);
    end
end

% concatenate the output folds from all input folds
bal_partitions=struct();
bal_partitions.train_indices=cat(2,train_indices_out{:});
bal_partitions.test_indices=cat(2,test_indices_out{:});

bal_partitions=ensure_sorted_indices(bal_partitions);
cosmo_check_partitions(bal_partitions,ds);
function partitions=ensure_sorted_indices(partitions)
% return the partitions with the indices of every fold in ascending order
% (sorting an already-sorted vector leaves it unchanged)
for key_cell={'train_indices','test_indices'}
    key=key_cell{1};
    partitions.(key)=cellfun(@sort,partitions.(key),...
                                'UniformOutput',false);
end
function tr_folds_out=sample_indices(target_idx,fold_class_pos,...
                                        nfolds_out,params)
% draw balanced samples and map the per-class positions back to
% dataset-wide sample indices
%
% target_idx      sample indices of the current fold
% fold_class_pos  cell with, for each class, positions into target_idx
% nfolds_out      number of output folds to generate
% params          struct with field .seed
pos_per_fold=sample_class_pos(fold_class_pos,nfolds_out,params);

% assign the sampled indices to each output fold
tr_folds_out=cellfun(@(pos)target_idx(pos),pos_per_fold,...
                        'UniformOutput',false);
function [classes,class_pos]=get_classes(targets)
% return the unique target values, and for each unique value a vector
% with the positions at which it occurs in targets
[class_pos,unique_values]=cosmo_index_unique({targets});
classes=unique_values{1};
function nfolds=get_nfolds_out(class_pos,params)
% return how many output folds are needed based on the sample indices
% for each class
%
% class_pos  cell with, for each class, the positions of its samples
% params     struct; at most one of these fields may be present:
%              .nmin      each sample must occur at least nmin times
%                         in each training set
%              .nrepeats  number of output folds per input fold
%            with neither field present, a single fold is returned
if isfield(params,'nmin')
    if isfield(params,'nrepeats')
        % fixed error message: both option names are now quoted
        error(['options ''nmin'' and ''nrepeats'' are '...
                    'mutually exclusive']);
    end
    % each fold takes min(class size) samples per class; enough folds
    % are needed so that even the largest class is covered nmin times
    targets_hist=cellfun(@numel,class_pos);
    nsamples_ratio=max(targets_hist)/min(targets_hist);
    nfolds=ceil(nsamples_ratio)*params.nmin;
elseif isfield(params,'nrepeats')
    nfolds=params.nrepeats;
else
    nfolds=1;
end
function folds=sample_class_pos(class_pos,nfolds,params)
% return nfolds folds, each with a balanced pseudo-random sample from
% class_pos: min(class size) positions from every class
%
% class_pos  cell with, for each class, a vector of positions
% nfolds     number of folds to generate
% params     struct with field .seed, passed to cosmo_rand
nclasses=numel(class_pos);
class_count=cellfun(@numel,class_pos);

% every fold takes the same number of samples from each class, namely
% the size of the smallest class
nsamples_per_class=min(class_count);

boundaries=[0;cumsum(class_count)];
nsamples=boundaries(end);

% single call to generate pseudo-random uniform data; each class uses
% its own contiguous slice, so results are reproducible for a seed
uniform_random_all=cosmo_rand(nsamples,1,'seed',params.seed);

idxs=cell(nfolds,nclasses);

% process each class separately
for k=1:nclasses
    % pseudo-random permutation of the positions of the k-th class
    uniform_random_pos=(boundaries(k)+1):boundaries(k+1);
    [unused,i]=sort(uniform_random_all(uniform_random_pos));

    % build sequence by repeating the random indices as many times as
    % necessary to fill all folds
    nrepeats=ceil(nsamples_per_class*nfolds/numel(i));
    seq=repmat(i,1,nrepeats);

    for j=1:nfolds
        % the j-th fold takes the j-th consecutive slice of the sequence
        seq_idx=nsamples_per_class*(j-1)+(1:nsamples_per_class);
        idxs{j,k}=class_pos{k}(seq(seq_idx));
    end
end

% concatenate the samples from all classes in each fold
folds=cell(1,nfolds);
for j=1:nfolds
    folds{j}=cat(1,idxs{j,:});
end