function test_suite = test_balance_dataset
% tests for cosmo_balance_dataset
%
% Entry point of this test file: collects the local test functions (when
% the running Matlab version supports 'localfunctions') and then runs the
% initTestSuite script.
%
% NOTE(review): initTestSuite presumably reads 'test_functions' from this
% workspace and assigns 'test_suite'; keep the variable name and the
% statement order unchanged - confirm against the test framework in use.
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
initTestSuite;
function r = randint()
    % Return a random integer-valued scalar: a uniform random value
    % from rand() is scaled by 10, shifted by 10 and rounded up.
    lower_bound = 10;
    span = 10;
    shifted = lower_bound + span * rand();
    r = ceil(shifted);
function test_balance_dataset_basics
    % Build a dataset with a random number of classes, each with its own
    % random number of samples, shuffle it, and verify that
    % cosmo_balance_dataset:
    %   - reports the classes that are present,
    %   - returns indices of size [min(nreps), nclasses] that reproduce
    %     the balanced dataset when used to slice the input, and
    %   - keeps min(nreps) samples per class, all of them different.
    nclasses = randint();
    ds_cell = cell(nclasses, 1);
    nreps = zeros(nclasses, 1);
    classes = zeros(nclasses, 1);

    class_id = 0;
    for k = 1:nclasses
        nrep = randint();

        % increment by a positive random step, so that class labels are
        % strictly increasing (and thus unique) but not consecutive
        class_id = class_id + randint();

        nreps(k) = nrep;
        classes(k) = class_id;

        ds_k = cosmo_synthetic_dataset('nchunks', 1, ...
                                       'ntargets', 1, ...
                                       'nreps', nrep, ...
                                       'seed', 0);
        ds_k.sa.targets(:) = class_id;
        ds_cell{k} = ds_k;
    end

    ds = cosmo_stack(ds_cell);
    nsamples = numel(ds.sa.chunks);

    % every sample gets its own chunk value
    ds.sa.chunks(:) = 1:nsamples;

    % store a unique identifier in the first feature of every sample, so
    % that the "all selected samples are different" check below does not
    % depend on the generated data happening to be distinct.
    % (this used to be written to ds.sa.samples, which is never read)
    ds.samples(:, 1) = 1:nsamples;

    % shuffle the sample order ('unused' instead of '~' is kept as in the
    % original code)
    [unused, i] = sort(randn(nsamples, 1));
    ds = cosmo_slice(ds, i);

    [balanced_ds, idxs, balanced_classes] = cosmo_balance_dataset(ds);

    assertEqual(unique(classes), balanced_classes);
    assertEqual(cosmo_slice(ds, idxs(:), 1), balanced_ds);
    assertEqual(size(idxs), [min(nreps), nclasses]);

    for k = 1:nclasses
        msk = balanced_ds.sa.targets == classes(k);

        % correct number of selected samples
        assertEqual(sum(msk), min(nreps));

        % all selected samples are different
        assertEqual(unique(balanced_ds.samples(msk, 1)), ...
                    sort(balanced_ds.samples(msk, 1)));
    end
function test_balance_dataset_exceptions
    % Verify that cosmo_balance_dataset accepts a valid dataset with
    % unique chunk values and raises an exception for each kind of
    % invalid input tried below.

    % helper: assert that calling cosmo_balance_dataset with the given
    % arguments throws; the exception identifier is given as ''
    aet = @(varargin)assertExceptionThrown(@() ...
                                 cosmo_balance_dataset(varargin{:}), '');

    ds = cosmo_synthetic_dataset('ntargets', randint());
    nsamples = size(ds.samples, 1);
    ds.sa.chunks(:) = 1:nsamples;

    % this should be ok
    cosmo_balance_dataset(ds);

    % not a dataset
    aet(struct);

    % not all chunks unique raises an exception
    bad_ds = ds;
    bad_ds.sa.chunks(1) = bad_ds.sa.chunks(2);
    aet(bad_ds);

    % missing samples
    bad_ds = rmfield(ds, 'samples');
    aet(bad_ds);

    % missing targets
    bad_ds = ds;
    bad_ds.sa = rmfield(bad_ds.sa, 'targets');
    aet(bad_ds);

    % missing chunks
    % (this case used to remove 'targets' a second time, so a missing
    %  .sa.chunks field was never tested)
    bad_ds = ds;
    bad_ds.sa = rmfield(bad_ds.sa, 'chunks');
    aet(bad_ds);