test correlation measure

function test_suite = test_correlation_measure
    % tests for cosmo_correlation_measure
    %
    % Entry point of this test file: collects the local test functions
    % (where supported) and then runs the initTestSuite script.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % NOTE(review): 'test_functions' is presumably picked up by name by
        % initTestSuite below - do not rename without confirming
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % deliberately empty: failure of localfunctions() is not an error
    end
    initTestSuite;

function test_correlation_measure_basis()
    % Compares the default, 'correlation' and 'mean_by_fold' outputs of
    % cosmo_correlation_measure with values computed directly through
    % corr and atanh, and verifies that reordering the samples of a
    % dataset does not change the result.

    % This test requires statistics functions
    cosmo_skip_test_if_no_external('#stats');

    n_targets = 4;
    n_chunks = 3;
    chunk_offset = 10;
    target_offset = 20;

    ds_full = cosmo_synthetic_dataset('nchunks', n_chunks, ...
                                      'ntargets', n_targets);
    ds_two = cosmo_slice(ds_full, ds_full.sa.chunks <= 2);

    % shift chunk and target values away from 1..N
    ds_two.sa.chunks = ds_two.sa.chunks + chunk_offset;
    ds_two.sa.targets = ds_two.sa.targets + target_offset;

    first_half = ds_two.samples(ds_two.sa.chunks == chunk_offset + 1, :);
    second_half = ds_two.samples(ds_two.sa.chunks == chunk_offset + 2, :);

    % Fisher-transformed correlations between the two halves
    z_expected = atanh(corr(first_half', second_half'));

    % expected default output: mean on-diagonal minus mean off-diagonal
    on_diag = logical(eye(n_targets));
    expected_delta = mean(z_expected(on_diag)) - mean(z_expected(~on_diag));

    res_default = cosmo_correlation_measure(ds_two);
    assertElementsAlmostEqual(expected_delta, res_default.samples, ...
                              'relative', 1e-5);
    assertEqual(res_default.sa.labels, {'corr'});

    % reset state and do not show warnings; the onCleanup object restores
    % the original warning state when this function exits
    saved_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(saved_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    res_corr = cosmo_correlation_measure(ds_two, 'output', 'correlation');
    assertElementsAlmostEqual(z_expected(:), res_corr.samples);

    target_range = (1:n_targets)';
    assertEqual(kron(target_range, ones(n_targets, 1)), res_corr.sa.half2);
    assertEqual(repmat(target_range, n_targets, 1), res_corr.sa.half1);

    % spot-check one element through its half1 / half2 indices
    pos = 7;
    row = res_corr.sa.half1(pos);
    col = res_corr.sa.half2(pos);
    assertElementsAlmostEqual(z_expected(row, col), res_corr.samples(pos));

    assertEqual({'half1', 'half2'}, res_corr.a.sdim.labels);
    shifted_targets = target_offset + target_range;
    assertEqual({shifted_targets, shifted_targets}, res_corr.a.sdim.values);

    % each 'mean_by_fold' value must equal the 'mean' output obtained
    % when one block of samples is assigned chunk 2 and the rest chunk 1
    res_by_fold = cosmo_correlation_measure(ds_full, ...
                                            'output', 'mean_by_fold');
    n_samples = n_chunks * n_targets;
    for fold = 1:n_chunks
        chunk2_rows = (n_chunks - fold) * n_targets + (1:n_targets);
        chunk1_rows = setdiff(1:n_samples, chunk2_rows);

        ds_fold = ds_full;
        ds_fold.sa.chunks(chunk2_rows) = 2;
        ds_fold.sa.chunks(chunk1_rows) = 1;

        res_fold = cosmo_correlation_measure(ds_fold, 'output', 'mean');
        assertElementsAlmostEqual(res_fold.samples, ...
                                  res_by_fold.samples(fold));
    end

    % test permutations
    ds_ordered = cosmo_synthetic_dataset('nchunks', 2, 'ntargets', 10);
    shuffle = randperm(20);
    ds_shuffled = cosmo_slice(ds_ordered, shuffle);

    assertEqual(cosmo_correlation_measure(ds_ordered), ...
                cosmo_correlation_measure(ds_shuffled));

    corr_opt = struct('output', 'correlation');
    assertEqual(cosmo_correlation_measure(ds_ordered, corr_opt), ...
                cosmo_correlation_measure(ds_shuffled, corr_opt));

function test_correlation_measure_single_target
    % With all targets set to 1 and a scalar template of 1, the measure
    % must equal the Fisher-transformed correlation between the averaged
    % samples of the two groups obtained from mod(chunks, 2).

    % This test requires statistics functions
    cosmo_skip_test_if_no_external('#stats');
    for n_targets = 2:6
        ds = cosmo_synthetic_dataset('nchunks', 2, 'ntargets', n_targets);
        ds.samples = randn(size(ds.samples));
        ds.sa.targets(:) = 1;

        % split the samples into two groups by chunk parity
        halves = cosmo_index_unique(mod(ds.sa.chunks, 2));
        assert(numel(halves) == 2);

        avg_first = mean(ds.samples(halves{1}, :));
        avg_second = mean(ds.samples(halves{2}, :));
        expected = atanh(corr(avg_first', avg_second'));

        result = cosmo_correlation_measure(ds, 'template', 1);

        assertElementsAlmostEqual(expected, result.samples);
    end

function test_correlation_measure_regression()
    % Regression test using the non-Spearman parameter sets.
    use_spearman = false;
    helper_test_correlation_measure_regression(use_spearman);

function test_correlation_measure_regression_spearman()
    % Regression test using the Spearman parameter set; returns early when
    % cosmo_skip_test_if_no_external('@stats') yields a true value.
    skip_test = cosmo_skip_test_if_no_external('@stats');
    if skip_test
        return
    end

    use_spearman = true;
    helper_test_correlation_measure_regression(use_spearman);

function helper_test_correlation_measure_regression(test_spearman)
    % Runs cosmo_correlation_measure on a synthetic dataset for every
    % parameter set returned by get_regression_test_params and compares
    % the output with the stored expected values.
    %
    % Input:
    %   test_spearman   passed on to get_regression_test_params to select
    %                   which parameter sets are tested
    %
    % For each parameter set the following is verified:
    %   - .samples matches the stored samples (absolute tolerance 5e-3)
    %   - .sa has exactly the expected fields with the expected values
    %   - .a is absent when no sdim is expected; otherwise .a.sdim has
    %     exactly the expected fields with the expected values

    % reset state and do not show warnings; the onCleanup object restores
    % the original warning state when this function exits
    orig_warning_state = cosmo_warning();
    warning_cleaner = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    ds = cosmo_synthetic_dataset('ntargets', 3, 'nchunks', 5, 'sigma', .5);
    params = get_regression_test_params(test_spearman);

    n_params = numel(params);
    for k = 1:n_params
        param = params{k};

        args = param{1};
        samples = param{2};
        sa = param{3};   % key-value pairs
        sdim = param{4}; % key-value pairs, or empty
        res = cosmo_correlation_measure(ds, args{:});

        % test samples
        assertElementsAlmostEqual(res.samples, samples', 'absolute', 5e-3);

        % test sa
        keys = fieldnames(res.sa);
        assertEqual(sort(keys(:)), sort(sa(1:2:end))');
        for j = 1:2:numel(sa)
            key = sa{j};
            value = sa{j + 1};
            assertEqual(res.sa.(key), value(:));
        end

        % test sdim
        if isempty(sdim)
            assertFalse(isfield(res, 'a'));
        else
            keys = fieldnames(res.a.sdim);
            assertEqual(sort(keys(:)), sort(sdim(1:2:end))');
            % iterate over the sdim pairs themselves; the number of sa
            % pairs is unrelated to the number of sdim pairs
            for j = 1:2:numel(sdim)
                key = sdim{j};
                value = sdim{j + 1};
                sdim_value = res.a.sdim.(key);
                assertEqual(sdim_value(:), value(:));
            end
        end
    end

function params = get_regression_test_params(test_spearman)
    % Returns the stored regression-test cases for
    % cosmo_correlation_measure.
    %
    % Input:
    %   test_spearman   if true, return only the 'Spearman' case;
    %                   otherwise return the other cases
    %
    % Output:
    %   params          cell array; each element is a cell with four
    %                   entries:
    %                   1) input arguments for cosmo_correlation_measure
    %                   2) expected samples
    %                   3) expected sample attributes, as key-value pairs
    %                   4) expected sdim, as key-value pairs, or [] if no
    %                      sdim is expected
    if test_spearman
        params = {{{ 'corr_type' 'Spearman' }, ...
                   -0.228, ...
                   { 'labels' { 'corr' } }, ...
                   []}};
    else
        params = {{{ }, ...
                   -0.24, ...
                   { 'labels' { 'corr' } }, ...
                   []}, ...
                  {{ 'template' [-2 2 3; -1 1 2; 2 -4 -3] }, ...
                   2.48, ...
                   { 'labels' { 'corr' } }, ...
                   []}, ...
                  {{ 'merge_func' @(x)sum(abs(x), 1) }, ...
                   0.567, ...
                   { 'labels' { 'corr' } }, ...
                   []}, ...
                  {{ 'post_corr_func' @(x)x + 1 }, ...
                   -0.204, ...
                   { 'labels' { 'corr' } }, ...
                   []}, ...
                  {{ 'output' 'mean_by_fold' }, ...
                   [-0.289 -0.274 -0.532 -0.112 -0.269 ...
                    -0.0535 -0.203 -0.3 -0.198 -0.173], ...
                   { 'partition' [1 2 3 4 5 6 7 8 9 10] }, ...
                   []}, ...
                  {{ 'output' 'correlation' }, ...
                   [-0.649 0.0675 0.345 0.126 0.643 ...
                    0.413 0.266 0.399 0.0933], ...
                   { 'half1', [1 2 3 1 2 3 1 2 3], ...
                    'half2' [1 1 1 2 2 2 3 3 3] }, ...
                   { 'labels' { 'half1' 'half2' }, ...
                    'values' { [1 2 3]' [1 2 3]' } }
                  }};

    end

function test_correlation_measure_exceptions
    % Verifies that cosmo_correlation_measure throws an exception for a
    % number of unsupported option values and target configurations.

    % helper: assert that calling the measure with the given arguments
    % throws an exception
    assert_throws = @(varargin)assertExceptionThrown( ...
                        @()cosmo_correlation_measure(varargin{:}), '');

    % reset state and do not show warnings; the onCleanup object restores
    % the original warning state when this function exits
    saved_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(saved_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    ds = cosmo_synthetic_dataset('nchunks', 2);

    % options that must be rejected for the unmodified dataset
    rejected_options = {{'template', eye(4)}, ...
                        {'output', 'foo'}, ...
                        {'output', 'one_minus_correlation'}};
    for k = 1:numel(rejected_options)
        option = rejected_options{k};
        assert_throws(ds, option{:});
    end

    % single target throws exception
    ds.sa.targets(:) = 1;
    rejected_single_target = {{}, ...
                              {'template', 2}, ...
                              {'template', eye(2)}};
    for k = 1:numel(rejected_single_target)
        option = rejected_single_target{k};
        assert_throws(ds, option{:});
    end

    % a single sample with a different target must be rejected as well
    ds.sa.targets(1) = 2;
    assert_throws(ds);

function x = identity(x)
    % Returns the input unchanged: the output variable is the input
    % variable and the body is empty.
    % NOTE(review): not referenced anywhere in the visible part of this
    % file - confirm whether it is still needed.

function test_correlation_measure_warning_shown_if_no_defaults()
    % For each combination of 'post_corr_func' and 'output' options,
    % checks whether cosmo_correlation_measure registers a warning (as
    % reported in the shown_warnings field of the cosmo_warning state).

    % the onCleanup object restores the original warning state when this
    % function exits
    saved_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(saved_warning_state));

    % reset state and do not show warnings
    cosmo_warning('reset');
    cosmo_warning('off');

    % an empty entry means: do not set the option at all
    post_funcs = {[], @atanh};
    output_names = {[], 'mean', 'raw', 'correlation'};

    for func_idx = 1:numel(post_funcs)
        for out_idx = 1:numel(output_names)
            % build the options for this combination
            opt = struct();

            post_func = post_funcs{func_idx};
            if ~isempty(post_func)
                opt.post_corr_func = post_func;
            end

            output_name = output_names{out_idx};
            if ~isempty(output_name)
                opt.output = output_name;
            end

            % the first two entries of each list count as defaults; a
            % warning is expected as soon as either option is not a
            % default. With two entries in post_funcs the first condition
            % never holds, so only the output option triggers a warning.
            expect_warning = func_idx > 2 || out_idx > 2;

            % start from a clean warning state for every combination
            cosmo_warning('reset');
            cosmo_warning('off');

            ds = cosmo_synthetic_dataset('nchunks', 2);
            cosmo_correlation_measure(ds, opt);

            warning_state = cosmo_warning();
            showed_warning = ~isempty(warning_state.shown_warnings);

            assertEqual(expect_warning, showed_warning);
        end
    end

function test_correlation_measure_wrong_template_size()
    % A 4x4 template combined with a dataset created with two targets
    % must make cosmo_correlation_measure throw an exception.
    ds = cosmo_synthetic_dataset('nchunks', 2, 'ntargets', 2);

    % block-diagonal template with two [1 -1; -1 1] blocks
    contrast_block = [1 -1; -1 1];
    measure_args = struct('template', kron(eye(2), contrast_block));

    measure = @cosmo_correlation_measure;
    assertExceptionThrown(@()measure(ds, measure_args));