test meeg baseline correct

function test_suite = test_meeg_baseline_correct
    % tests for cosmo_meeg_baseline_correct
    %
    % Test-suite entry point for this file. The individual test cases are
    % the local functions defined below.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % The variable name 'test_functions' must be kept as is; presumably
    % initTestSuite looks it up in this workspace (confirm against the
    % test framework's documentation).
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % NOTE(review): initTestSuite is invoked without arguments or outputs;
    % it is expected to assign the 'test_suite' output variable itself.
    initTestSuite;

function test_meeg_baseline_correct_ft_comparison_methods()
    % Run the comparison helper once for every pairing of baseline method
    % and reference implementation, each time with a newly drawn random
    % baseline interval.
    bl_methods = {'relative', 'absolute', 'relchange'};
    bl_references = {'manual', 'ft', 'ds'};

    % each row of 'pairs' holds one {method, reference} combination
    pairs = cosmo_cartprod({bl_methods, bl_references});

    for row = 1:size(pairs, 1)
        [method, reference] = pairs{row, :};

        % random interval: non-positive onset, duration of at least .05
        onset = rand(1) * -.25;
        duration = .05 + rand(1) * .2;
        interval = onset + [0 duration];

        helper_test_meeg_baseline_correct_comparison(interval, ...
                                                     method, reference);
    end

function helper_test_meeg_baseline_correct_comparison(interval, method, ...
                                                      reference)
    % Compare the output of cosmo_meeg_baseline_correct with a reference
    %
    % Inputs:
    %   interval    two-element vector with the baseline time window
    %   method      baseline method: 'relative', 'absolute' or 'relchange'
    %   reference   how the expected result is computed:
    %               'ft'      through ft_freqbaseline (needs FieldTrip;
    %                         skipped on Octave)
    %               'ds'      by passing the baseline period as a dataset
    %                         instead of an interval
    %               'manual'  explicitly, per channel and frequency
    %
    % The test is skipped (early return) when the 'ft' reference cannot be
    % used on the current platform.

    switch reference
        case 'ft'
            if cosmo_skip_test_if_no_external('fieldtrip')
                return
            end

            if cosmo_wtf('is_octave')
                cosmo_notify_test_skipped(['ft_freqbaseline is not '...
                                           'compatible with Octave']);
                % do not continue into the FieldTrip code path below;
                % without this return the test would still call
                % ft_freqbaseline if the notification does not throw
                return
            end

        case 'ds'

        case 'manual'

        otherwise
            assert(false, 'this should not happen');
    end

    % build a random time-frequency dataset restricted to a few randomly
    % chosen channels and frequencies
    ds = cosmo_synthetic_dataset('type', 'timefreq', 'size', 'big', 'seed', 0);
    ds.samples = randn(size(ds.samples));

    chan_to_select = cosmo_randperm(max(ds.fa.chan), ceil(rand() * 3 + 1));
    freq_to_select = cosmo_randperm(max(ds.fa.freq), ceil(rand() * 3 + 1));

    m = cosmo_match(ds.fa.chan, chan_to_select) & ...
            cosmo_match(ds.fa.freq, freq_to_select);
    ds = cosmo_slice(ds, m, 2);
    ds = cosmo_dim_prune(ds);

    % result under test
    y = cosmo_meeg_baseline_correct(ds, interval, method);

    % expected result, computed according to the requested reference
    switch reference
        case 'ft'
            % (unsupported in octave)
            ft = cosmo_map2meeg(ds);

            opt = struct();
            opt.baseline = interval;
            opt.baselinetype = method;

            ft_bl = ft_freqbaseline(opt, ft);
            x = cosmo_meeg_dataset(ft_bl);

        case 'ds'
            msk = cosmo_dim_match(ds, 'time', ...
                                  @(t) t >= min(interval) & t <= max(interval));
            ds_ref = cosmo_slice(ds, msk, 2);
            x = cosmo_meeg_baseline_correct(ds, ds_ref, method);

        case 'manual'
            msk = cosmo_dim_match(ds, 'time', ...
                                  @(t) t >= min(interval) & t <= max(interval));
            ds_ref = cosmo_slice(ds, msk, 2);

            x = ds;

            for chan = 1:max(ds.fa.chan)
                for freq = 1:max(ds.fa.freq)
                    msk = ds.fa.chan == chan & ds.fa.freq == freq;
                    s = ds.samples(:, msk);

                    % per-sample mean over the baseline features of this
                    % channel and frequency
                    ref_msk = ds_ref.fa.chan == chan & ds_ref.fa.freq == freq;
                    r = mean(ds_ref.samples(:, ref_msk), 2);

                    switch method
                        case 'absolute'
                            v = bsxfun(@minus, s, r);

                        case 'relative'
                            v = bsxfun(@rdivide, s, r);

                        case 'relchange'
                            v = bsxfun(@rdivide, bsxfun(@minus, s, r), r);

                        otherwise
                            error('not supported: %s', method);
                    end

                    x.samples(:, msk) = v;
                end
            end

        otherwise
            assert(false, 'this should not happen');

    end

    % group features by unique (time, chan) combination
    x_unq = cosmo_index_unique({x.fa.time, x.fa.chan});
    y_unq = cosmo_index_unique({y.fa.time, y.fa.chan});

    % both results must have the same number of groups before any group
    % index is used to slice them
    assert(numel(x_unq) == numel(y_unq));

    % compare a random subset of at most max_n_to_choose groups
    n = numel(x_unq);
    max_n_to_choose = 10;
    n_to_choose = min(max_n_to_choose, n);
    rp = cosmo_randperm(n, n_to_choose);

    for j = 1:n_to_choose
        idx = rp(j);
        x_sel = cosmo_slice(x, x_unq{idx}, 2, false);
        y_sel = cosmo_slice(y, y_unq{idx}, 2, false);

        p = x_sel.samples;
        q = y_sel.samples;

        desc = sprintf('method %s, reference %s', method, reference);
        assertElementsAlmostEqual(p, q, ...
                                  'relative', 1e-6, desc);
        assertEqual(x_sel.fa, y_sel.fa);

    end

    assertEqual(x.a.fdim, y.a.fdim);

function test_meeg_baseline_correct_regression()
    % Regression test: baseline-corrected samples for channel 3 and
    % frequency 2 must match stored values, both when the baseline is given
    % as a time interval and when it is given as a reference dataset.
    interval = [-.15 -.04];
    [ds, ds_ref] = get_test_dataset(interval);

    % stored expected values, one matrix per baseline method
    expected_relative = [2.7634 -2.6366 2.3689 -0.36892 -0.43043; ...
                         -0.49007 -1.0933 0.79082 1.2092 0.1044];
    expected_absolute = [-0.63503 1.3096 -0.49296 0.49296 0.51511; ...
                         -2.5308 -3.5553 -0.35529 0.35529 -1.5211];
    expected_relchange = [1.7634 -3.6366 1.3689 -1.3689 -1.4304; ...
                          -1.4901 -2.0933 -0.20918 0.20918 -0.8956];

    bl_methods = {'relative', 'absolute', 'relchange'};
    expected_samples = {expected_relative, ...
                        expected_absolute, ...
                        expected_relchange};

    % the baseline is specified first as interval, then as dataset
    baselines = {interval, ds_ref};

    input_msk = ds.fa.chan == 3 & ds.fa.freq == 2;

    for m = 1:numel(bl_methods)
        for b = 1:numel(baselines)
            corrected = cosmo_meeg_baseline_correct(ds, baselines{b}, ...
                                                    bl_methods{m});

            output_msk = corrected.fa.chan == 3 & corrected.fa.freq == 2;
            selected = cosmo_slice(corrected, output_msk, 2);

            assertElementsAlmostEqual(selected.samples, ...
                                      expected_samples{m}, ...
                                      'absolute', 1e-4);

            % feature attributes must equal those of the input selection
            input_fa = cosmo_slice(ds.fa, input_msk, 2, 'struct');
            assertEqual(selected.fa, input_fa);
        end
    end

function [ds, ds_ref] = get_test_dataset(interval)
    % Build the dataset pair used by the regression test
    %
    % Input:
    %   interval    two-element vector [t_start t_end]
    %
    % Outputs:
    %   ds          synthetic time-frequency dataset restricted to
    %               channels <= 4 and frequencies <= 2, with .sa holding
    %               only a running 'rpt' index
    %   ds_ref      the features of ds whose time lies within interval
    ds = cosmo_synthetic_dataset('type', 'timefreq', ...
                                 'size', 'big', ...
                                 'senstype', 'neuromag306_planar', ...
                                 'nchunks', 1);

    % replace the sample attributes by a single repetition counter
    nsamples = size(ds.samples, 1);
    ds.sa = struct('rpt', (1:nsamples)');

    keep_msk = ds.fa.chan <= 4 & ds.fa.freq <= 2;
    ds = cosmo_dim_prune(cosmo_slice(ds, keep_msk, 2));

    % select the time points inside the (closed) interval
    in_interval = @(t) interval(1) <= t & t <= interval(2);
    ref_msk = cosmo_dim_match(ds, 'time', in_interval);
    ds_ref = cosmo_dim_prune(cosmo_slice(ds, ref_msk, 2));

function test_meeg_baseline_correct_nonmatching_sa
    % A reference dataset whose samples are in a different order than
    % those of the input dataset must raise an exception.
    ds = cosmo_synthetic_dataset('size', 'big', 'ntargets', 8, ...
                                 'nchunks', 1, 'type', 'timelock');
    nsamples = size(ds.samples, 1);
    identity = 1:nsamples;

    % draw permutations until one differs from the identity ordering
    perm = identity;
    while isequal(perm, identity)
        perm = cosmo_randperm(nsamples);
    end

    ds_ref = cosmo_slice(ds, perm);
    baseline_correct = @()cosmo_meeg_baseline_correct(ds, ds_ref, ...
                                                      'relative');
    assertExceptionThrown(baseline_correct, '');

function test_meeg_baseline_correct_nonmatching_fa
    % A reference dataset covering a different set of channels and
    % frequencies than the input dataset must raise an exception.
    full_ds = cosmo_synthetic_dataset('size', 'big', 'ntargets', 8, ...
                                      'nchunks', 1, 'type', 'timefreq');

    % input: channels 1-2, frequencies 1-3
    input_msk = full_ds.fa.chan <= 2 & full_ds.fa.freq <= 3;
    ds = cosmo_slice(full_ds, input_msk, 2);

    % reference: channels 1-3, frequencies 1-2
    ref_msk = full_ds.fa.chan <= 3 & full_ds.fa.freq <= 2;
    ds_ref = cosmo_slice(full_ds, ref_msk, 2);

    baseline_correct = @()cosmo_meeg_baseline_correct(ds, ds_ref, ...
                                                      'relative');
    assertExceptionThrown(baseline_correct, '');

function test_meeg_baseline_correct_illegal_inputs
    % Invalid calls to cosmo_meeg_baseline_correct must raise exceptions
    baseline_correct = @cosmo_meeg_baseline_correct;
    assert_throws = @assertExceptionThrown;
    ds = cosmo_synthetic_dataset('type', 'timefreq', 'size', 'tiny');

    % identifier raised for a missing argument differs per platform
    if cosmo_wtf('is_matlab')
        id_missing_arg = {'MATLAB:inputArgUndefined', 'MATLAB:minrhs'};
    else
        id_missing_arg = 'Octave:undefined-function';
    end

    % method argument missing
    assert_throws(@()baseline_correct(ds, ds), id_missing_arg);

    % unknown method
    assert_throws(@()baseline_correct(ds, ds, 'foo'), '');

    % reference with only the first sample, or only the first feature
    first_sample = @()baseline_correct(ds, cosmo_slice(ds, 1), 'relative');
    assert_throws(first_sample, '');
    first_feature = @()baseline_correct(ds, cosmo_slice(ds, 1, 2), ...
                                        'relative');
    assert_throws(first_feature, '');

    % test slicing: samples in original order are accepted ...
    in_order = cosmo_slice(ds, [1 2 3 4 5 6], 1);
    baseline_correct(ds, in_order, 'relative');

    % ... but a reordered selection with a repeated sample is not
    reordered = @()baseline_correct(ds, ...
                                    cosmo_slice(ds, [1 2 4 6 4 3], 1), ...
                                    'relative');
    assert_throws(reordered, '');