function is_ok = cosmo_check_partitions(partitions, ds, varargin)
% check whether partitions are balanced and not double-dippy
%
% is_ok=cosmo_check_partitions(partitions, ds, ...)
%
% Inputs:
%   partitions          struct with partitions, e.g. from
%                       cosmo_{nfold,oddeven,nchoosek}_partitioner, with
%                       fields .train_indices and .test_indices; both must
%                       be cells with the same number of elements (one
%                       vector of sample indices per fold)
%   ds                  dataset struct with fields .sa.{targets,chunks}
%   'unbalanced_partitions_ok', u
%                       (optional, default: false) if set to true, then
%                       unbalanced partitions (with a different number of
%                       training samples for the classes used in a fold)
%                       are ok. This option can also be given as a field
%                       of an options struct.
%
% Output:
%   is_ok               boolean, true if partitions are ok (the function
%                       raises an error otherwise)
%
% Throws:
%   - an error if .train_indices or .test_indices are missing, are not
%     cells, or have a different number of folds
%   - an error if, in any fold, the indices are empty, not integers, or
%     outside the range 1..nsamples
%   - an error if partitions are double dippy, i.e. a chunk is used both
%     for training and testing in the same fold
%   - unless unbalanced partitions are allowed, an error if the set of
%     targets used for training differs across folds, or if the number of
%     training samples differs between classes in a fold
%
% Warns:
%   - if .train_indices or .test_indices in any fold are not sorted
%
% Examples:
%     ds=struct();
%     ds.samples=zeros(9,2);
%     ds.sa.targets=[1 1 2 2 2 3 3 3 3]';
%     ds.sa.chunks=[1 2 2 1 1 1 2 2 2]';
%     % make unbalanced partitions
%     partitions=cosmo_nfold_partitioner(ds);
%     cosmo_check_partitions(partitions,ds);
%     %|| error('Unbalance in partition 1 [...]')
%     %
%     % disable unbalanced check
%     cosmo_check_partitions(partitions,ds,'unbalanced_partitions_ok',true)
%     %|| true
%     %
%     % balance partitions and check without unbalanced check
%     partitions=cosmo_balance_partitions(partitions,ds);
%     cosmo_check_partitions(partitions,ds,'unbalanced_partitions_ok',false)
%     %|| true
%     %
%     % make the partitions double dippy
%     partitions.train_indices{1}=partitions.test_indices{1};
%     cosmo_check_partitions(partitions,ds,'unbalanced_partitions_ok',true)
%     %|| error('double dipping in fold 1: chunk 1 is in train and test set')
%     %
%     % make partitions empty
%     partitions.train_indices{1}=[];
%     cosmo_check_partitions(partitions,ds);
%     %|| error('partition 1: .train_indices are empty')
%     %
%     % partitions have values outside range
%     partitions.train_indices{1}=100;
%     cosmo_check_partitions(partitions,ds);
%     %|| error('partition 1: .train_indices are outside range 1:9');
%     %
%     % use non-integers
%     partitions.train_indices{1}=1.5;
%     cosmo_check_partitions(partitions,ds);
%     %|| error('partition 1: .train_indices are not integers');
%
% Notes:
%   - the reason to require balancing by default is that chance level is
%     1/nclasses, which is useful for many subsequent analyses.
%   - if this function raises an exception for partitions, consider running
%     partitions=cosmo_balance_partitions(partitions,...).
%
% See also: cosmo_balance_partitions, cosmo_nfold_partitioner
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    % process input arguments
    defaults = struct();
    defaults.unbalanced_partitions_ok = false;
    params = cosmo_structjoin(defaults, varargin{:});

    % whether to check for an equal number of training samples of each
    % class in each fold
    check_balance = ~params.unbalanced_partitions_ok;

    % whether to check for the same chunk in train and test set
    check_double_dipping = true;

    % check dataset
    check_dataset(ds);

    % ensure it has targets and chunks
    cosmo_isfield(ds, {'sa.targets', 'sa.chunks'}, true);
    targets = ds.sa.targets;
    chunks = ds.sa.chunks;

    if check_balance
        % sample2class(i) is the position of targets(i) in classes
        [classes, unused, sample2class] = unique(targets);
        nclasses = numel(classes);
    end

    % must have .train_indices and .test_indices
    cosmo_isfield(partitions, {'train_indices', 'test_indices'}, true);
    train_indices = partitions.train_indices;
    test_indices = partitions.test_indices;

    if ~iscell(train_indices) || ~iscell(test_indices)
        error('.train_indices and .test_indices must be a cell');
    end

    % ensure equal number of partitions for train and test
    npartitions = numel(train_indices);
    if npartitions ~= numel(test_indices)
        error('Partition count mismatch for train and test: %d ~= %d', ...
                npartitions, numel(test_indices));
    end

    nsamples = numel(targets);

    % column 1 (2) is true for folds with unsorted train (test) indices
    unsorted_train_test_fold = false(npartitions, 2);

    for k = 1:npartitions
        train_idxs = train_indices{k};
        test_idxs = test_indices{k};

        check_range(train_idxs, nsamples, k, 'train');
        check_range(test_idxs, nsamples, k, 'test');

        unsorted_train_test_fold(k, 1) = ~issorted(train_idxs);
        unsorted_train_test_fold(k, 2) = ~issorted(test_idxs);

        if check_balance
            % counts of number of training samples in each class must be
            % the same
            train_classes = sample2class(train_idxs);
            h = histc(train_classes, 1:nclasses);
            h_nonzero_idxs = find(h > 0);

            % the classes used for training must be the same in all folds
            if k == 1
                first_h_nonzero_idxs = h_nonzero_idxs;
            elseif ~isequal(first_h_nonzero_idxs, h_nonzero_idxs)
                raise_target_mismatch_error(1, k);
            end

            h_nonzero = h(h_nonzero_idxs);
            if ~all(h_nonzero(1) == h_nonzero)
                idx = find(h_nonzero(1) ~= h_nonzero, 1);
                pos_first = h_nonzero_idxs(1);
                pos_other = h_nonzero_idxs(idx);
                raise_unbalance_error(k, classes(pos_first), h(pos_first), ...
                                        classes(pos_other), h(pos_other));
            end
        end

        if check_double_dipping
            % no chunk may be used both for training and for testing
            train_chunks = chunks(train_idxs);
            in_test = cosmo_match(train_chunks, chunks(test_idxs));
            if any(in_test)
                idx = find(in_test, 1);
                error(['double dipping in fold %d: chunk %d is in '...
                        'train and test set'], k, train_chunks(idx));
            end
        end
    end

    if any(unsorted_train_test_fold(:))
        warn_unsorted_indices(unsorted_train_test_fold);
    end

    is_ok = true;


function raise_target_mismatch_error(first_fold, fold)
% raise an error explaining that two folds use different training targets
    error(['Different targets used for training '...
            'in partition %d and %d. This is weird. '...
            'Consider the following scenarios:\n'...
            '(1) You made the partitions manually. It is '...
            'possible that you made a mistake.\n'...
            '(2) You try to do cross-decoding and '...
            'partitions were defined using '...
            'cosmo_nchoosek_partitioner. This usually '...
            'requires a *re-assignment* of .sa.targets '...
            'so that the unique targets values are the '...
            'same for the samples used for training '...
            'and for testing. Please read the '...
            'documentation of cosmo_nchoosek_partitioner '...
            '(especially the examples) carefully and '...
            'verify that '...
            'the .sa.targets are (re)assigned properly.\n'...
            '(3) You used a cosmo_ function to set the '...
            'partitions, but case 2 does not apply. '...
            'It may indicate a bug; in that case, '...
            'please get in touch with the CoSMoMVPA '...
            'developers'], first_fold, fold);


function raise_unbalance_error(fold, class_first, count_first, ...
                                class_other, count_other)
% raise an error explaining that two classes in a fold have a different
% number of training samples
    error(['Unbalance in partition %d, '...
            'classes %d (#=%d) and %d (#=%d). '...
            'Consider the following scenarios:\n'...
            '(1) the input is an MEEG dataset, or an fMRI '...
            'dataset with missing samples (trials) in some '...
            'of the runs. You probably want to use '...
            'cosmo_balance_partitions.\n'...
            '(2) all chunks are unique; this is either a (a) '...
            'between-participants design and you '...
            'try to discriminate between participants (e.g. '...
            'patients versus controls); or this is an MEEG '...
            'dataset where all samples are assumed to be '...
            'independent. You probably want to use '...
            'cosmo_independent_samples_partitioner to define '...
            'the partitions\n'...
            '(3) the input is an fMRI dataset using beta '...
            'estimates or t-statistics estimated using the '...
            'GLM, with one sample '...
            'of each condition per run (e.g. nfold partition) '...
            'or per set of runs (e.g. odd-even partition). '...
            'Probably a mistake was made setting .sa.chunks '...
            'or .sa.targets\n'...
            '(4) you *really* know what you are doing: '...
            'as a litmus test, you would be comfortable '...
            'implementing a bootstrapping algorithm to '...
            'estimate the cdf of your measure of interest '...
            'under some null hypothesis. You can set '...
            'unbalanced_partitions_ok to true as an option'], ...
            fold, class_first, count_first, ...
            class_other, count_other);


function warn_unsorted_indices(unsorted_train_test_fold)
% warn about folds whose .train_indices (column 1 of the input) or
% .test_indices (column 2) are not sorted
    fns = {'train_indices', 'test_indices'};
    msg_parts = {'', '', ['Unsorted partitions can lead to unintuitive '...
                        'ordering of results (e.g. when using fold_predictions)']};
    for k = 1:2
        folds = find(unsorted_train_test_fold(:, k));
        if ~isempty(folds)
            msg_parts{k} = sprintf('Unsorted %s in fold(s):%s', ...
                                    fns{k}, sprintf(' %d', folds));
        end
    end

    % only join the parts that are non-empty
    msk = ~cellfun(@isempty, msg_parts);
    msg = cosmo_strjoin(msg_parts(msk), '\n');
    cosmo_warning(msg);

function check_range(idxs, nsamples, partition, label)
% raise an error unless the indices of one fold are non-empty, integer
% and within 1..nsamples
%
% Inputs:
%   idxs          indices of one fold (train or test)
%   nsamples      number of samples in the dataset
%   partition     fold number, used only in the error message
%   label         'train' or 'test', used only in the error message
    if isempty(idxs)
        reason = 'empty';
    elseif ~isequal(idxs, round(idxs))
        reason = 'not integers';
    elseif min(idxs) < 1 || max(idxs) > nsamples
        reason = sprintf('outside range 1:%d', nsamples);
    else
        % all checks passed
        return;
    end

    error('partition %d: .%s_indices are %s', partition, label, reason);

function check_dataset(ds)
% ensure ds is a struct with a field .sa, and run cosmo_check_dataset on it
% unless its .sa equals the .sa of the most recently validated dataset
%
% The .sa of the last dataset that passed cosmo_check_dataset is kept in a
% persistent variable, so repeated calls with the same .sa skip the
% (potentially expensive) full dataset check.
    persistent last_checked_sa

    has_sa = isstruct(ds) && isfield(ds, 'sa');
    if ~has_sa
        error('second input must be a dataset struct with field .sa');
    end

    if isequal(ds.sa, last_checked_sa)
        % same .sa as a previously validated dataset; skip the check
        return;
    end

    cosmo_check_dataset(ds);
    last_checked_sa = ds.sa;