function test_suite = test_average_samples
% tests for cosmo_average_samples
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
initTestSuite;
function test_average_samples_
ds = cosmo_synthetic_dataset();
a = cosmo_average_samples(ds);
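% by default, the synthetic dataset has one sample per unique
% (targets, chunks) pair, so averaging should only reorder rows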
assertElementsAlmostEqual(sort(a.samples), sort(ds.samples));
assertElementsAlmostEqual(sort(a.samples(:, 3)), sort(ds.samples(:, 3)));
a = cosmo_average_samples(ds, 'ratio', .5);
assertElementsAlmostEqual(sort(a.samples), sort(ds.samples));
assertElementsAlmostEqual(sort(a.samples(:, 3)), sort(ds.samples(:, 3)));
% check wrong inputs
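% aet asserts that calling cosmo_average_samples with the given
% arguments throws an exception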
aet = @(varargin)assertExceptionThrown(@() ...
cosmo_average_samples(varargin{:}), '');
aet(ds, 'ratio', .1);
aet(ds, 'ratio', 3);
aet(ds, 'ratio', .5, 'count', 2);
ds.sa.chunks(:) = 1;
a = cosmo_average_samples(ds, 'ratio', .5);
cosmo_check_dataset(a);
ds = cosmo_slice(ds, 3, 2);
ns = size(ds.samples, 1);
ds.samples = ds.sa.targets * 1000 + (1:ns)';
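% samples now encode the target in the thousands and the sample index
% in the units, so the averages reveal which samples were combined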
a = cosmo_average_samples(ds, 'ratio', .5, 'nrep', 10);
% no mixing of different targets
delta = a.samples / 1000 - a.sa.targets;
assertTrue(all(.00099 <= delta & delta < .05));
assertElementsAlmostEqual(delta * 3000, round(delta * 3000));
a = cosmo_average_samples(ds, 'count', 3, 'nrep', 10);
% no mixing of different targets
delta = a.samples / 1000 - a.sa.targets;
assertTrue(all(.00099 <= delta & delta < .05));
assertElementsAlmostEqual(delta * 3000, round(delta * 3000));
function test_average_samples_split_by
plural_singular = {'targets', 'targets'; ...
'chunks', 'chunks'; ...
'subjects', 'subject'; ...
'modalities', 'modality' ...
};
n_dim = size(plural_singular, 1);
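% enumerate all 2^n_dim combinations of attributes to split by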
combis = cosmo_cartprod(repmat({{true, false}}, n_dim, 1)');
for k = 1:size(combis, 1)
combi = cell2mat(combis(k, :));
opt = struct();
opt.seed = 0; % no fixed seed: different pseudo-random data on each run
for j = 1:n_dim
count = ceil(rand() * 2 + 1);
opt.(['n' plural_singular{j, 1}]) = count;
end
ds = cosmo_synthetic_dataset(opt);
values = cell(n_dim, 1);
for j = 1:n_dim
if combi(j)
values{j} = ds.sa.(plural_singular{j, 2});
end
end
values = values(combi);
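% group samples by the unique combinations of the selected attributes;
% if no attributes are selected, all samples form a single group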
if any(combi)
[idx, unq_cell] = cosmo_index_unique(values);
else
idx = {1:(size(ds.samples, 1))};
end
n_avg = numel(idx);
n_features = size(ds.samples, 2);
expected_samples = zeros(n_avg, n_features);
for m = 1:n_avg
expected_samples(m, :) = mean(ds.samples(idx{m}, :), 1);
end
result = cosmo_average_samples(ds, ...
'split_by', plural_singular(combi, 2));
assertEqual(size(result.samples), size(expected_samples));
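% the row order of the result is not guaranteed, so match each result
% row to an expected row through its first feature value (with random
% data, ties are extremely unlikely)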
delta = bsxfun(@minus, result.samples(:, 1), expected_samples(:, 1)');
mapping = zeros(1, n_avg);
for m = 1:n_avg
[mn, mn_idx] = min(abs(delta(m, :)));
assert(mn < 1e-5); % deal with rounding
mapping(mn_idx) = m;
end
assertEqual(sort(mapping), 1:n_avg);
result_perm = cosmo_slice(result, mapping);
assertElementsAlmostEqual(result_perm.samples, expected_samples);
pos = 0;
for j = 1:n_dim
if combi(j)
pos = pos + 1;
fn = plural_singular{j, 2};
assertEqual(unq_cell{pos}, result_perm.sa.(fn));
end
end
% check default result
if isequal(plural_singular(combi, 1), {'targets'; 'chunks'})
default_result = cosmo_average_samples(ds);
assertEqual(result, default_result);
end
end
function test_average_samples_split_by_empty
ds = cosmo_synthetic_dataset('ntargets', ceil(rand() * 5 + 2), ...
'nchunks', ceil(rand() * 5 + 2));
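% an empty 'split_by' puts all samples in a single group, so the
% result is the mean over all samples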
result = cosmo_average_samples(ds, 'split_by', {});
assertElementsAlmostEqual(result.samples, mean(ds.samples, 1));
function test_average_samples_exceptions
aet = @(varargin)assertExceptionThrown(@() ...
cosmo_average_samples(varargin{:}), '');
ds = cosmo_synthetic_dataset('nreps', 5);
aet([]);
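% a struct with samples but no sample attributes is not a valid dataset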
x = struct();
x.samples = randn(4);
aet(x);
% illegal count
aet(ds, 'count', 6);
aet(ds, 'count', [2 2]);
aet(ds, 'count', 3.5);
aet(ds, 'count', 0);
% illegal ratio
aet(ds, 'ratio', 1.2);
aet(ds, 'ratio', -0.2);
aet(ds, 'ratio', [.5 .5]);
% mutually exclusive options
aet(ds, 'ratio', .5, 'count', 2);
% illegal 'repeats' / 'resamplings' values
aet(ds, 'repeats', [2 2]);
aet(ds, 'repeats', -1);
aet(ds, 'resamplings', [2 2]);
aet(ds, 'resamplings', -1);
% 'resamplings' and 'repeats' are mutually exclusive
aet(ds, 'resamplings', 1, 'repeats', 1);
% required 'targets' field is missing
ds_bad = ds;
ds_bad.sa = rmfield(ds_bad.sa, 'targets');
aet(ds_bad);
% illegal 'split_by' arguments
aet(ds, 'split_by', []);
aet(ds, 'split_by', struct());
aet(ds, 'split_by', 'foo');
aet(ds, 'split_by', {1, 2});
function test_average_samples_with_repeats
nchunks = ceil(rand() * 4 + 3);
ntargets = ceil(rand() * 4 + 3);
ncombi_max = ceil(rand() * 3 + 4);
max_cyc = 5;
ncombi_min = ceil(ncombi_max / 2);
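% strategy: encode the feature index, chunk, target and repeat id of
% every value as bit fields of a single number, so that after averaging
% it can be verified exactly which samples were combined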
ds = cosmo_synthetic_dataset('nchunks', nchunks, ...
'ntargets', ntargets, ...
'nreps', ncombi_max);
ds.sa = rmfield(ds.sa, 'rep');
sp = cosmo_split(ds, {'targets', 'chunks'});
n_splits = numel(sp);
% select a subset of samples so that each (chunk, target) combination
% has at least ncombi_min repeats
combi_count = zeros(nchunks, ntargets);
for k = 1:n_splits
if k == 1
% ensure at least one combination has exactly the minimum number of repeats
nkeep = ncombi_min;
else
nkeep = ncombi_min + floor(rand() * (ncombi_max - ncombi_min));
end
ds_k = cosmo_slice(sp{k}, 1:nkeep);
ds_k.sa.repeats = (1:nkeep)';
combi_count(ds_k.sa.chunks(1), ds_k.sa.targets(1)) = nkeep;
sp{k} = ds_k;
end
assert(all(cellfun(@(x)size(x.samples, 1), sp) > 0)); % no empty splits
ds = cosmo_stack(sp);
[nsamples, nfeatures] = size(ds.samples);
% bit widths for features, chunks, targets, and repeats
bws = [nfeatures, nchunks, ntargets, ceil(log2(max_cyc + 1)) + ncombi_max];
% encode feature index, chunk, target and repeat of every value into a
% single number
dsb = binarize_ds(ds, bws);
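% dsb.samples(k, j) now encodes (feature j, chunk, target, repeat) of
% the k-th sample as a single number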
% helper: run the check with the dataset layout parameters fixed
check_with = @(args, ...
count, ...
repeats) check_with_helper(dsb, args, count, repeats, ...
nchunks, ntargets, ...
ncombi_max, combi_count, ...
bws);
for repeats = [1, ceil(rand() * ncombi_max)]
for count = [1, ceil(rand() * ncombi_min)]
check_with({'count', count, 'repeats', repeats}, ...
count, repeats);
end
for ratio = [.5, .3 + rand() * .7]
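% expected number of samples per average: the ratio applied to the
% smallest number of samples per (chunk, target) combination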
count = round(ratio * min(combi_count(:)));
check_with({'ratio', ratio, 'repeats', repeats}, ...
count, repeats);
end
end
for resamplings = [0, 1, 2 + round(rand() * 4)]
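% resamplings == 0 tests the default behavior; otherwise the expected
% number of averages per (chunk, target) combination scales with
% 'resamplings'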
count = ceil(rand() * ncombi_min);
if resamplings == 0
repeats = floor(ncombi_min / count);
args = {'count', count};
else
repeats = floor(resamplings * ncombi_min / count);
args = {'count', count, 'resamplings', resamplings};
end
check_with(args, count, repeats);
end
function check_with_helper(dsb, args, count, repeats, ...
nchunks, ntargets, ncombi_max, combi_count, bws)
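% average the binarized dataset with the given arguments, then verify
% that every output row is the average of exactly 'count' distinct
% samples from a single (chunk, target) combination, and that samples
% are re-used in a balanced way across the 'repeats' averages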
mu = cosmo_average_samples(dsb, args{:});
[chunks, targets, ids] = unbinarize_ds(mu, bws, count);
nsamples = size(ids, 1);
nfeatures = size(dsb.samples, 2);
% counts how often each (chunk, target, repeat id) is used
ctr_count = zeros(nchunks, ntargets, ncombi_max);
% tally, for every output row, which samples were averaged
for j = 1:nsamples
for k = 1:nfeatures
% the same samples must have been selected for every feature
id = ids{j, k};
if k == 1
first_id = id;
else
assertEqual(first_id, id);
end
end
% no sample may occur more than once within a single average
id_sorted = sort(id(:));
assert(all(diff(id_sorted) > 0));
% count should match
assertEqual(numel(id), count);
ctr_count(chunks(j), targets(j), id) = ...
ctr_count(chunks(j), targets(j), id) + 1;
end
% ensure each sample is selected about equally often
[nchunks, ntargets] = size(combi_count);
for k = 1:nchunks
for j = 1:ntargets
c = squeeze(ctr_count(k, j, :));
pre = c(1:combi_count(k, j));
assert(max(pre) - min(pre) <= 1);
post = c((combi_count(k, j) + 1):end);
assert(all(post == 0));
end
end
% check each target and chunk combination was used the correct number
% of times to form the average
ct_count = sum(ctr_count, 3);
expected_ct_count = count * repeats * ones(nchunks, ntargets);
assert(isequal(ct_count, expected_ct_count));
function [chunks, targets, ids] = unbinarize_ds(ds, bws, counts)
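% recover, for every element of the averaged dataset, the chunk, the
% target, and the ids of the 'counts' samples that were averaged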
[nsamples, nfeatures] = size(ds.samples);
ids = cell(nsamples, nfeatures);
chunks = zeros(nsamples, 1);
targets = zeros(nsamples, 1);
for k = 1:nsamples
for j = 1:nfeatures
% decode which samples ('repeats') contributed to this average: undo
% the averaging by multiplying with the number of averaged samples;
% each set bit in the resulting repeat field marks one selected sample
v_id = quick_dec2bin(mod(ds.samples(k, j) * counts, ...
2^bws(end)), ...
bws(end));
ids{k, j} = bws(end) - find(v_id) + 1;
% decode feature index, chunk and target from the high bits
v = decode(floor(ds.samples(k, j) / 2^bws(end)), bws(1:(end - 1)));
assertEqual(log2(v(1)) + 1, j);
c = log2(v(2)) + 1;
t = log2(v(3)) + 1;
if j == 1
chunks(k) = c;
targets(k) = t;
else
assertEqual(c, chunks(k));
assertEqual(t, targets(k));
end
end
end
function bds = binarize_ds(ds, bws)
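% replace every value in the dataset by a number encoding its feature
% index, chunk, target and repeat id as separate one-hot bit fields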
bds = ds;
[nsamples, nfeatures] = size(ds.samples);
for k = 1:nsamples
sa = cosmo_slice(ds.sa, k, 1, 'struct');
for j = 1:nfeatures
vs = [j, sa.chunks sa.targets sa.repeats];
bds.samples(k, j) = encode(vs, bws);
end
end
function p = encode(vs, bws)
% encode several values into a single number by concatenating one-hot
% bit fields:
%   encode([X1 ... Xn], [B1 ... Bn]) =
%       bin2dec([dec2bin(2^(X1-1), B1) ... dec2bin(2^(Xn-1), Bn)])
% where bws = [B1 ... Bn] contains the bit width for each value
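% illustration (values chosen for this comment only):
%   encode([3 2], [4 2]) sets bit 2^(3-1) in a 4-bit field and bit
%   2^(2-1) in a 2-bit field, i.e. bin2dec('010010') == 18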
n = numel(bws);
assert(numel(vs) == n);
bs = cell(1, n);
for k = 1:n
bw = bws(k);
bs{k} = zeros(1, bw);
bs{k}(bw - vs(k) + 1) = 1;
end
p = quick_bin2dec(cat(2, bs{:}));
function vs = decode(p, bws)
% decode a single decimal number into multiple ones:
%   decode(P, [B1 ... Bn]) = [bin2dec(PB1) ... bin2dec(PBn)]
% where PBk is the part of P's binary representation that falls in the
% k-th bit field, with field widths given by bws
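% illustration (matching the encode example above): decode(18, [4 2])
% splits binary '010010' into fields '0100' and '10', giving [4 2],
% i.e. [2^(3-1) 2^(2-1)]; the original values follow from log2(.)+1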
arr = quick_dec2bin(p, sum(bws));
c = 0;
n = numel(bws);
vs = zeros(1, n);
for k = 1:n
offset = bws(k);
vs(k) = quick_bin2dec(arr(c + (1:offset)));
c = c + offset;
end
function arr = quick_dec2bin(x, bw)
% convert decimal number x to a binary array of length bw, most
% significant bit first, with values 0 or 1
assert(round(x) == x);
arr = zeros(1, bw);
xbs = dec2bin(x);
arr(bw - numel(xbs) + 1:end) = (xbs == '1');
function x = quick_bin2dec(arr)
% convert binary array to decimal number
x = sum(2.^((numel(arr) - 1):-1:0) .* arr);