function test_suite = test_correlation_measure
% tests for cosmo_correlation_measure
%
% This file is a test suite; running it returns the collection of the
% local test functions defined below.
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
% collect handles to all local functions in this file; the variable name
% 'test_functions' must not change, as the initTestSuite script reads it
% from this workspace
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
% initTestSuite is a script (not a function); presumably it assigns the
% output variable test_suite in this workspace - verify against the
% installed MOxUnit/xUnit version
initTestSuite;
function test_correlation_measure_basis()
% Check cosmo_correlation_measure against a direct split-half computation,
% check its 'correlation' and 'mean_by_fold' outputs, and check that the
% result does not depend on the order of the samples.

    % corr is needed for the reference computation
    cosmo_skip_test_if_no_external('#stats');

    ds_three = cosmo_synthetic_dataset('nchunks', 3, 'ntargets', 4);

    % two-chunk dataset with shifted chunk and target values
    ds = cosmo_slice(ds_three, ds_three.sa.chunks <= 2);
    ds.sa.chunks = ds.sa.chunks + 10;
    ds.sa.targets = ds.sa.targets + 20;

    % reference: Fisher-transformed correlations between the two halves
    half_a = ds.samples(ds.sa.chunks == 11, :);
    half_b = ds.samples(ds.sa.chunks == 12, :);
    z_ab = atanh(corr(half_a', half_b'));

    on_diag = eye(4) > 0;
    expected_delta = mean(z_ab(on_diag)) - mean(z_ab(~on_diag));

    res_mean = cosmo_correlation_measure(ds);
    assertElementsAlmostEqual(expected_delta, res_mean.samples, ...
        'relative', 1e-5);
    assertEqual(res_mean.sa.labels, {'corr'});

    % silence warnings from here on; the original state is restored
    % when this function exits
    orig_warning_state = cosmo_warning();
    warning_cleaner = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    % full correlation matrix output, stored column-major
    res_corr = cosmo_correlation_measure(ds, 'output', 'correlation');
    assertElementsAlmostEqual(reshape(z_ab, [], 1), res_corr.samples);
    assertEqual(kron((1:4)', ones(4, 1)), res_corr.sa.half2);
    assertEqual(repmat((1:4)', 4, 1), res_corr.sa.half1);

    row = 7;
    assertElementsAlmostEqual( ...
        z_ab(res_corr.sa.half1(row), res_corr.sa.half2(row)), ...
        res_corr.samples(row));
    assertEqual({'half1', 'half2'}, res_corr.a.sdim.labels);
    assertEqual({20 + (1:4)', 20 + (1:4)'}, res_corr.a.sdim.values);

    % each fold of 'mean_by_fold' must equal a single-split 'mean' run
    res_fold = cosmo_correlation_measure(ds_three, 'output', 'mean_by_fold');
    for fold = 1:3
        one_side = (3 - fold) * 4 + (1:4);
        other_side = setdiff(1:12, one_side);

        ds_split = ds_three;
        ds_split.sa.chunks(one_side) = 2;
        ds_split.sa.chunks(other_side) = 1;

        res_split = cosmo_correlation_measure(ds_split, 'output', 'mean');
        assertElementsAlmostEqual(res_split.samples, res_fold.samples(fold));
    end

    % permuting the samples must not change the result
    ds_ten = cosmo_synthetic_dataset('nchunks', 2, 'ntargets', 10);
    ds_ten_perm = cosmo_slice(ds_ten, randperm(20));
    assertEqual(cosmo_correlation_measure(ds_ten), ...
        cosmo_correlation_measure(ds_ten_perm));

    corr_opt = struct('output', 'correlation');
    assertEqual(cosmo_correlation_measure(ds_ten, corr_opt), ...
        cosmo_correlation_measure(ds_ten_perm, corr_opt));
function test_correlation_measure_single_target
% When all samples share one target and 'template' is 1, the measure is
% expected to equal the Fisher-transformed correlation between the mean
% patterns of the two groups of samples formed by chunk parity.

    % corr is needed for the reference computation
    cosmo_skip_test_if_no_external('#stats');

    for n_targets = 2:6
        ds = cosmo_synthetic_dataset('nchunks', 2, 'ntargets', n_targets);
        ds.samples = randn(size(ds.samples));
        ds.sa.targets(:) = 1;

        % group sample indices by odd/even chunk
        halves = cosmo_index_unique(mod(ds.sa.chunks, 2));
        assert(numel(halves) == 2);

        mean_first = mean(ds.samples(halves{1}, :));
        mean_second = mean(ds.samples(halves{2}, :));
        expected = atanh(corr(mean_first', mean_second'));

        result = cosmo_correlation_measure(ds, 'template', 1);
        assertElementsAlmostEqual(expected, result.samples);
    end
function test_correlation_measure_regression()
% Regression test of cosmo_correlation_measure for the non-Spearman
% parameter sets
    use_spearman = false;
    helper_test_correlation_measure_regression(use_spearman);
function test_correlation_measure_regression_spearman()
% Regression test of cosmo_correlation_measure with Spearman correlation;
% returns early if the '@stats' external is reported as missing
    skipped = cosmo_skip_test_if_no_external('@stats');
    if skipped
        return
    end

    use_spearman = true;
    helper_test_correlation_measure_regression(use_spearman);
function helper_test_correlation_measure_regression(test_spearman)
% Run cosmo_correlation_measure on a fixed synthetic dataset for each
% parameter set from get_regression_test_params, and compare the output
% samples, sample attributes (.sa) and sample dimensions (.a.sdim) with
% the stored expected values.
%
% Input:
%   test_spearman   passed to get_regression_test_params: true selects the
%                   Spearman parameter set, false the other sets

    % reset state and do not show warnings; restore the state on exit
    orig_warning_state = cosmo_warning();
    warning_cleaner = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    ds = cosmo_synthetic_dataset('ntargets', 3, 'nchunks', 5, 'sigma', .5);
    params = get_regression_test_params(test_spearman);
    n_params = numel(params);

    for k = 1:n_params
        param = params{k};
        args = param{1};
        samples = param{2};
        sa = param{3};
        sdim = param{4};

        res = cosmo_correlation_measure(ds, args{:});

        % test samples; expected values are stored with limited precision
        assertElementsAlmostEqual(res.samples, samples', 'absolute', 5e-3);

        % test sa: identical set of field names, and the value of each field
        keys = fieldnames(res.sa);
        assertEqual(sort(keys(:)), sort(sa(1:2:end))');
        for j = 1:2:numel(sa)
            key = sa{j};
            value = sa{j + 1};
            assertEqual(res.sa.(key), value(:));
        end

        % test sdim: an empty expectation means no .a field at all
        if isempty(sdim)
            assertFalse(isfield(res, 'a'));
        else
            keys = fieldnames(res.a.sdim);
            assertEqual(sort(keys(:)), sort(sdim(1:2:end))');
            % iterate over the sdim key-value pairs (not those of sa)
            for j = 1:2:numel(sdim)
                key = sdim{j};
                value = sdim{j + 1};
                sdim_value = res.a.sdim.(key);
                assertEqual(sdim_value(:), value(:));
            end
        end
    end
function params = get_regression_test_params(test_spearman)
% Return the parameter sets for the regression tests of
% cosmo_correlation_measure.
%
% Input:
%   test_spearman   if true, return only the Spearman parameter set;
%                   otherwise return the remaining parameter sets
% Output:
%   params          1xN cell; each element is a 1x4 cell containing
%                   1) input arguments for cosmo_correlation_measure
%                   2) expected samples (row vector)
%                   3) expected sample attributes, as key-value pairs
%                   4) expected sdim, as key-value pairs, or [] if the
%                      result is expected to have no .a field

    params = {};

    if test_spearman
        params{end + 1} = {{'corr_type', 'Spearman'}, ...
                           -0.228, ...
                           {'labels', {'corr'}}, ...
                           []};
        return
    end

    corr_label = {'labels', {'corr'}};

    params{end + 1} = {{}, -0.24, corr_label, []};

    params{end + 1} = {{'template', [-2 2 3; -1 1 2; 2 -4 -3]}, ...
                       2.48, corr_label, []};

    params{end + 1} = {{'merge_func', @(x)sum(abs(x), 1)}, ...
                       0.567, corr_label, []};

    params{end + 1} = {{'post_corr_func', @(x)x + 1}, ...
                       -0.204, corr_label, []};

    params{end + 1} = {{'output', 'mean_by_fold'}, ...
                       [-0.289 -0.274 -0.532 -0.112 -0.269 ...
                        -0.0535 -0.203 -0.3 -0.198 -0.173], ...
                       {'partition', 1:10}, ...
                       []};

    params{end + 1} = {{'output', 'correlation'}, ...
                       [-0.649 0.0675 0.345 0.126 0.643 ...
                        0.413 0.266 0.399 0.0933], ...
                       {'half1', [1 2 3 1 2 3 1 2 3], ...
                        'half2', [1 1 1 2 2 2 3 3 3]}, ...
                       {'labels', {'half1', 'half2'}, ...
                        'values', {[1 2 3]', [1 2 3]'}}};
function test_correlation_measure_exceptions
% Invalid options, and datasets with a single target value, must make
% cosmo_correlation_measure raise an error.

    assert_raises = @(varargin)assertExceptionThrown(@() ...
        cosmo_correlation_measure(varargin{:}), '');

    % reset state and do not show warnings; restore the state on exit
    orig_warning_state = cosmo_warning();
    warning_cleaner = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    ds = cosmo_synthetic_dataset('nchunks', 2);

    % invalid template / output options
    bad_options = {{'template', eye(4)}, ...
                   {'output', 'foo'}, ...
                   {'output', 'one_minus_correlation'}};
    for k = 1:numel(bad_options)
        opts = bad_options{k};
        assert_raises(ds, opts{:});
    end

    % single target throws exception, with or without a template
    ds.sa.targets(:) = 1;
    single_target_options = {{}, ...
                             {'template', 2}, ...
                             {'template', eye(2)}};
    for k = 1:numel(single_target_options)
        opts = single_target_options{k};
        assert_raises(ds, opts{:});
    end

    % still an error after changing the target of only the first sample
    ds.sa.targets(1) = 2;
    assert_raises(ds);
function y = identity(x)
% Helper: return the input unchanged
    y = x;
function test_correlation_measure_warning_shown_if_no_defaults()
% This test expects cosmo_correlation_measure to show a warning exactly
% when 'post_corr_func' or 'output' is set to a non-default value. The
% defaults are @atanh and 'mean'; [] below means the option is not passed.

    orig_warning_state = cosmo_warning();
    cleaner = onCleanup(@()cosmo_warning(orig_warning_state));

    % reset state and do not show warnings
    cosmo_warning('reset');
    cosmo_warning('off');

    % the first n_default_* entries of each list are default settings;
    % the remaining entries are non-default. @identity (a local function
    % in this file) provides a non-default post_corr_func, so that the
    % non-default-function branch is actually exercised.
    funcs = {[], @atanh, @identity};
    outputs = {[], 'mean', 'raw', 'correlation'};
    n_default_funcs = 2;
    n_default_outputs = 2;

    % the dataset is the same for every combination
    ds = cosmo_synthetic_dataset('nchunks', 2);

    for k = 1:numel(funcs)
        func = funcs{k};
        for j = 1:numel(outputs)
            output = outputs{j};

            is_default_func = k <= n_default_funcs;
            is_default_output = j <= n_default_outputs;
            expect_warning = ~(is_default_func && is_default_output);

            opt = struct();
            if ~isempty(func)
                opt.post_corr_func = func;
            end
            if ~isempty(output)
                opt.output = output;
            end

            % start each combination with an empty record of warnings
            cosmo_warning('reset');
            cosmo_warning('off');

            cosmo_correlation_measure(ds, opt);

            s = cosmo_warning();
            showed_warning = numel(s.shown_warnings) > 0;
            assertEqual(expect_warning, showed_warning);
        end
    end
function test_correlation_measure_wrong_template_size()
% A 4x4 template combined with a dataset that has 2 targets must make
% cosmo_correlation_measure raise an error.

    ds = cosmo_synthetic_dataset('nchunks', 2, 'ntargets', 2);

    bad_template = [1 -1 0 0; ...
                    -1 1 0 0; ...
                    0 0 1 -1; ...
                    0 0 -1 1];
    measure_args = struct('template', bad_template);

    assertExceptionThrown(@()cosmo_correlation_measure(ds, measure_args));