function [balanced_ds, idxs, classes] = cosmo_balance_dataset(ds, varargin)
% sub-sample a dataset to have an equal number of samples for each target
%
% [balanced_ds,idxs,classes]=cosmo_balance_dataset(ds,...)
%
% Inputs:
%   ds                    dataset struct with fields .samples,
%                         .sa.targets and .sa.chunks. All values in
%                         .sa.chunks must be different from each other.
%   'sample_balancer', f  (optional)
%                         function handle with signature
%                           [idxs,classes]=f(targets,seed)
%                         where idxs is an SxC matrix with sample indices
%                         for C classes and S samples per class, and
%                         classes contains the corresponding class labels.
%                         If omitted or empty, a builtin function is used.
%   'seed', s             (optional, default=1)
%                         Seed to use for pseudo-random number generation;
%                         it is passed as the second argument to the
%                         sample balancer.
%
% Output:
%   balanced_ds           dataset with a subset of the samples from ds
%                         so that each target occurs equally often.
%                         Selection is (by default) done in a
%                         pseudo-deterministic manner. Samples are grouped
%                         by class: first all selected samples for
%                         classes(1), then those for classes(2), etc.
%   idxs                  SxC matrix; idxs(:,c) contains the indices of the
%                         samples in ds that were selected for class
%                         classes(c), so that balanced_ds corresponds to
%                         cosmo_slice(ds,idxs(:),1).
%   classes               Cx1 vector containing unique class labels
%
% Notes:
%   - this function is to be used with MEEG datasets. It is not intended
%     for fMRI data.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
opt = cosmo_structjoin(struct('seed', 1, 'sample_balancer', []), ...
                       varargin{:});
check_inputs(ds, opt);

[idxs, classes] = select_subset(ds.sa.targets, opt);

% idxs is SxC; flattening it column by column keeps the selected samples
% grouped by class in the output dataset
keep = idxs(:);
balanced_ds = cosmo_slice(ds, keep, 1);
function [idxs, classes] = select_subset(targets, opt)
% Pick which samples to keep for each class.
%
% Uses the function handle in opt.sample_balancer; when that field is
% empty, default_sample_balancer is used instead. The chosen balancer is
% called as balancer(targets, opt.seed), and its two outputs are returned
% unchanged.
if isempty(opt.sample_balancer)
    balancer = @default_sample_balancer;
else
    balancer = opt.sample_balancer;
end

[idxs, classes] = balancer(targets, opt.seed);
function [idxs, classes] = default_sample_balancer(targets, seed)
% Builtin balancer: for every unique target, keep a random subset of its
% samples whose size equals the size of the smallest class.
%
% Inputs:
%   targets   target labels, grouped per unique value by cosmo_index_unique
%   seed      seed passed on to cosmo_rand
%
% Outputs:
%   idxs      matrix with one column per class; column c holds the
%             selected entries of the c-th index group returned by
%             cosmo_index_unique (presumably sample indices into targets)
%   classes   second output of cosmo_index_unique (the class labels)
[class_positions, classes] = cosmo_index_unique(targets);
counts = cellfun(@numel, class_positions);
n_keep = min(counts);
n_classes = numel(counts);

% a single PRNG call provides random values for all classes at once;
% column c (truncated to counts(c) rows) is used for class c
noise = cosmo_rand(max(counts), n_classes, 'seed', seed);

idxs = zeros(n_keep, n_classes);
for c = 1:n_classes
    positions = class_positions{c};
    % sorting random values yields a random permutation of 1:counts(c)
    [unused, order] = sort(noise(1:counts(c), c));
    idxs(:, c) = positions(order(1:n_keep));
end
function check_inputs(ds, opt)
% Validate the dataset passed to cosmo_balance_dataset.
%
% Inputs:
%   ds    dataset struct; checked with cosmo_check_dataset and required to
%         have .sa.targets and .sa.chunks
%   opt   options struct (currently not inspected here)
%
% Raises an error when .sa.chunks contains repeated values. The other
% checks are delegated to cosmo_check_dataset and cosmo_isfield, which are
% called with raise_exception=true and presumably error on failure.
cosmo_check_dataset(ds);

% chunks and targets must be present
raise_exception = true;
cosmo_isfield(ds, {'sa.targets', 'sa.chunks'}, raise_exception);

% every sample must have its own chunk value (i.e. samples are treated
% as independent); any duplicate value makes unique() shorter
chunks = ds.sa.chunks;
if numel(chunks) ~= numel(unique(chunks))
    % NOTE: the suggested fix uses numel(x) rather than numel(x,1);
    % the latter is the indexed-count form and evaluates to 1, which
    % would assign the same chunk to every sample
    error(['All values in .sa.chunks must be unique. If '...
            '*and only if* all '...
            'observations in .samples can be assumed to be '...
            'independent, for a dataset ds you can set '...
            ' ds.sa.chunks(:)=1:numel(ds.sa.chunks) '...
            'to indicate independence. This assumption typically '...
            'only applies to M/EEG datasets; this function should '...
            'not be used for typical fMRI datasets']);
end