function test_suite = test_meeg_baseline_correct
% tests for cosmo_meeg_baseline_correct
%
% test_suite is not assigned in this function itself; it is presumably
% set by the initTestSuite script below (NOTE(review): confirm against
% the xUnit / MOxUnit version in use).
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #

% Collect handles to the local (sub)functions of this file, so that the
% test framework can find the test_* functions; localfunctions does not
% exist in older Matlab versions (or Octave), hence the try/catch.
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
initTestSuite;
function test_meeg_baseline_correct_ft_comparison_methods()
% compare cosmo_meeg_baseline_correct with reference implementations
%
% Every combination of baseline method ('relative', 'absolute',
% 'relchange') and reference implementation ('manual', 'ft', 'ds') is
% checked, each time with a freshly drawn random baseline interval.
baseline_methods = {'relative', 'absolute', 'relchange'};
reference_types = {'manual', 'ft', 'ds'};
combinations = cosmo_cartprod({baseline_methods, reference_types});

for row = 1:size(combinations, 1)
    % interval onset in [-0.25, 0], duration in [0.05, 0.25]
    onset = -.25 * rand(1);
    duration = .05 + .2 * rand(1);

    helper_test_meeg_baseline_correct_comparison( ...
                [onset, onset + duration], ...
                combinations{row, 1}, ...
                combinations{row, 2});
end
function helper_test_meeg_baseline_correct_comparison(interval, method, ...
                                                      reference)
% compare cosmo_meeg_baseline_correct against a reference implementation
%
% helper_test_meeg_baseline_correct_comparison(interval, method, reference)
%
% Inputs:
%   interval     two-element vector with the baseline time window; time
%                points t with min(interval) <= t <= max(interval) are
%                used as baseline
%   method       baseline method: 'relative', 'absolute' or 'relchange'
%   reference    implementation to compare against:
%                'ft'     - FieldTrip's ft_freqbaseline (the comparison is
%                           skipped when FieldTrip is not available, or
%                           when running under Octave)
%                'ds'     - cosmo_meeg_baseline_correct itself, with the
%                           baseline passed as a dataset instead of as a
%                           time interval
%                'manual' - explicit computation per channel/frequency
%
% An assertion error is raised when the two results differ.

switch reference
    case 'ft'
        if cosmo_skip_test_if_no_external('fieldtrip')
            return
        end

        if cosmo_wtf('is_octave')
            cosmo_notify_test_skipped(['ft_freqbaseline is not '...
                                        'compatible with Octave']);
            % the comparison is reported as skipped; do not go on and
            % call ft_freqbaseline anyway
            return
        end

    case {'ds', 'manual'}
        % no external requirements

    otherwise
        assert(false, 'this should not happen');
end

% random time-frequency data, restricted to a random subset of 2 to 4
% channels and 2 to 4 frequencies to keep the test fast
ds = cosmo_synthetic_dataset('type', 'timefreq', 'size', 'big', 'seed', 0);
ds.samples = randn(size(ds.samples));
chan_to_select = cosmo_randperm(max(ds.fa.chan), ceil(rand() * 3 + 1));
freq_to_select = cosmo_randperm(max(ds.fa.freq), ceil(rand() * 3 + 1));
m = cosmo_match(ds.fa.chan, chan_to_select) & ...
        cosmo_match(ds.fa.freq, freq_to_select);
ds = cosmo_slice(ds, m, 2);
ds = cosmo_dim_prune(ds);

% result under test
y = cosmo_meeg_baseline_correct(ds, interval, method);

% result from the reference implementation
switch reference
    case 'ft'
        ft = cosmo_map2meeg(ds);
        opt = struct();
        opt.baseline = interval;
        opt.baselinetype = method;
        ft_bl = ft_freqbaseline(opt, ft);
        x = cosmo_meeg_dataset(ft_bl);

    case 'ds'
        ds_ref = helper_select_baseline_interval(ds, interval);
        x = cosmo_meeg_baseline_correct(ds, ds_ref, method);

    case 'manual'
        ds_ref = helper_select_baseline_interval(ds, interval);
        x = ds;
        for chan = 1:max(ds.fa.chan)
            for freq = 1:max(ds.fa.freq)
                feature_msk = ds.fa.chan == chan & ds.fa.freq == freq;
                s = ds.samples(:, feature_msk);

                % average over the baseline time points, per sample
                ref_msk = ds_ref.fa.chan == chan & ds_ref.fa.freq == freq;
                r = mean(ds_ref.samples(:, ref_msk), 2);

                switch method
                    case 'absolute'
                        v = bsxfun(@minus, s, r);
                    case 'relative'
                        v = bsxfun(@rdivide, s, r);
                    case 'relchange'
                        v = bsxfun(@rdivide, bsxfun(@minus, s, r), r);
                    otherwise
                        error('not supported: %s', method);
                end

                x.samples(:, feature_msk) = v;
            end
        end

    otherwise
        assert(false);
end

% compare a random subset (at most 10) of the feature groups that share
% the same (time, channel) combination
x_unq = cosmo_index_unique({x.fa.time, x.fa.chan});
y_unq = cosmo_index_unique({y.fa.time, y.fa.chan});
assert(numel(x_unq) == numel(y_unq));

n = numel(x_unq);
max_n_to_choose = 10;
n_to_choose = min(max_n_to_choose, n);
rp = cosmo_randperm(n, n_to_choose);

desc = sprintf('method %s, reference %s', method, reference);
for j = 1:n_to_choose
    idx = rp(j);
    x_sel = cosmo_slice(x, x_unq{idx}, 2, false);
    y_sel = cosmo_slice(y, y_unq{idx}, 2, false);

    assertElementsAlmostEqual(x_sel.samples, y_sel.samples, ...
                              'relative', 1e-6, desc);
    assertEqual(x_sel.fa, y_sel.fa);
end

assertEqual(x.a.fdim, y.a.fdim);

function ds_ref = helper_select_baseline_interval(ds, interval)
% return the features of ds whose time lies within the closed interval
% [min(interval), max(interval)] (dimensions are not pruned)
msk = cosmo_dim_match(ds, 'time', ...
                      @(t) t >= min(interval) & t <= max(interval));
ds_ref = cosmo_slice(ds, msk, 2);
function test_meeg_baseline_correct_regression()
% regression test against previously computed baseline-corrected values
%
% For channel 3 and frequency 2 the output of cosmo_meeg_baseline_correct
% is compared with stored values, for each method and with the baseline
% given either as a time interval or as a reference dataset. The feature
% attributes of the output must equal those of the input.
interval = [-.15 -.04];
[ds, ds_ref] = get_test_dataset(interval);

method_names = {'relative', 'absolute', 'relchange'};

% expected{i} holds the expected values for method_names{i}
expected = {[2.7634 -2.6366 2.3689 -0.36892 -0.43043; ...
             -0.49007 -1.0933 0.79082 1.2092 0.1044], ...
            [-0.63503 1.3096 -0.49296 0.49296 0.51511; ...
             -2.5308 -3.5553 -0.35529 0.35529 -1.5211], ...
            [1.7634 -3.6366 1.3689 -1.3689 -1.4304; ...
             -1.4901 -2.0933 -0.20918 0.20918 -0.8956]};

% feature attributes that the selected output features must have
input_msk = ds.fa.chan == 3 & ds.fa.freq == 2;
expected_fa = cosmo_slice(ds.fa, input_msk, 2, 'struct');

% baseline specified as time interval, and as dataset
baselines = {interval, ds_ref};

for i_method = 1:numel(method_names)
    method = method_names{i_method};

    for i_baseline = 1:numel(baselines)
        ds_bl = cosmo_meeg_baseline_correct(ds, baselines{i_baseline}, ...
                                            method);
        output_msk = ds_bl.fa.chan == 3 & ds_bl.fa.freq == 2;
        d = cosmo_slice(ds_bl, output_msk, 2);

        assertElementsAlmostEqual(d.samples, expected{i_method}, ...
                                  'absolute', 1e-4);
        assertEqual(d.fa, expected_fa);
    end
end
function [ds, ds_ref] = get_test_dataset(interval)
% build a small time-frequency dataset and its baseline subset
%
% [ds, ds_ref] = get_test_dataset(interval)
%
% Input:
%   interval    [start stop] time window, both ends inclusive
% Outputs:
%   ds          synthetic time-frequency dataset restricted to channels
%               1-4 and frequencies 1-2, with .sa containing only a
%               running 'rpt' index
%   ds_ref      the features of ds with time inside interval
%               (dimensions pruned)
ds = cosmo_synthetic_dataset('type', 'timefreq', 'size', 'big', ...
                             'senstype', 'neuromag306_planar', ...
                             'nchunks', 1);

nsamples = size(ds.samples, 1);
ds.sa = struct('rpt', (1:nsamples)');

keep = ds.fa.freq <= 2 & ds.fa.chan <= 4;
ds = cosmo_dim_prune(cosmo_slice(ds, keep, 2));

in_interval = @(t) t >= interval(1) & t <= interval(2);
time_msk = cosmo_dim_match(ds, 'time', in_interval);
ds_ref = cosmo_dim_prune(cosmo_slice(ds, time_msk, 2));
function test_meeg_baseline_correct_nonmatching_sa
% a reference dataset with the samples in a different order than the
% input dataset must raise an error
ds = cosmo_synthetic_dataset('size', 'big', 'ntargets', 8, ...
                             'nchunks', 1, 'type', 'timelock');
identity = 1:size(ds.samples, 1);

% draw random permutations until one differs from the original order
rp = identity;
while isequal(rp, identity)
    rp = cosmo_randperm(numel(identity));
end

ds_ref = cosmo_slice(ds, rp);
assertExceptionThrown(@()cosmo_meeg_baseline_correct(ds, ds_ref, ...
                                                     'relative'), '');
function test_meeg_baseline_correct_nonmatching_fa
% a reference dataset covering a different set of channels and
% frequencies than the input dataset must raise an error
ds_full = cosmo_synthetic_dataset('size', 'big', 'ntargets', 8, ...
                                  'nchunks', 1, 'type', 'timefreq');
chan = ds_full.fa.chan;
freq = ds_full.fa.freq;

ds = cosmo_slice(ds_full, chan <= 2 & freq <= 3, 2);
ds_ref = cosmo_slice(ds_full, chan <= 3 & freq <= 2, 2);

assertExceptionThrown(@()cosmo_meeg_baseline_correct(ds, ds_ref, ...
                                                     'relative'), '');
function test_meeg_baseline_correct_illegal_inputs
% cosmo_meeg_baseline_correct must reject missing or unknown arguments
% and reference datasets that do not match the input dataset
ds = cosmo_synthetic_dataset('type', 'timefreq', 'size', 'tiny');
baseline_fn = @cosmo_meeg_baseline_correct;
expect_error = @assertExceptionThrown;

% the error identifier for a missing argument is platform-specific
if cosmo_wtf('is_matlab')
    missing_arg_ids = {'MATLAB:inputArgUndefined', 'MATLAB:minrhs'};
else
    missing_arg_ids = 'Octave:undefined-function';
end

% no method given
expect_error(@()baseline_fn(ds, ds), missing_arg_ids);

% unknown method
expect_error(@()baseline_fn(ds, ds, 'foo'), '');

% reference with only a single sample, or only a single feature
expect_error(@()baseline_fn(ds, cosmo_slice(ds, 1), 'relative'), '');
expect_error(@()baseline_fn(ds, cosmo_slice(ds, 1, 2), 'relative'), '');

% test slicing: samples 1..6 in their original order are accepted, but a
% reordered selection with a repeated sample is not
baseline_fn(ds, cosmo_slice(ds, 1:6, 1), 'relative');
expect_error(@()baseline_fn(ds, cosmo_slice(ds, [1 2 4 6 4 3], 1), ...
                            'relative'), '');