test target dsm corr measure

function test_suite = test_target_dsm_corr_measure
    % tests for cosmo_target_dsm_corr_measure
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_target_dsm_corr_measure_pearson
    ds = cosmo_synthetic_dataset('ntargets', 6, 'nchunks', 1);
    mat1 = [1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];

    dcm1 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', mat1);
    assertElementsAlmostEqual(dcm1.samples, 0.2507, 'absolute', 1e-4);
    assertEqual(dcm1.sa.labels, {'rho'});
    assertEqual(dcm1.sa.metric, {'correlation'});
    assertEqual(dcm1.sa.type, {'Pearson'});

    distance_ds = cosmo_pdist(ds.samples, 'correlation');
    assertElementsAlmostEqual(cosmo_corr(distance_ds', mat1'), dcm1.samples);

    sq1 = cosmo_squareform(mat1);
    dcm2 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', sq1);
    assertElementsAlmostEqual(dcm1.samples, dcm2.samples);
    dcm2.samples = dcm1.samples;
    assertEqual(dcm1, dcm2);

    dcm3 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', sq1, ...
                                         'metric', 'euclidean');
    assertElementsAlmostEqual(dcm3.samples, 0.3037, 'absolute', 1e-4);

function test_target_dsm_corr_measure_partial
    ds = cosmo_synthetic_dataset('ntargets', 6, 'nchunks', 1);
    mat1 = [1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];
    mat2 = mat1(end:-1:1);

    dcm1 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', mat1, ...
                                         'regress_dsm', mat2);
    assertElementsAlmostEqual(dcm1.samples, 0.3082, 'absolute', 1e-4);

function test_target_dsm_corr_measure_partial_vector_partialcorr
    if cosmo_skip_test_if_no_external('!partialcorr')
        return
    end
    ntargets = ceil(rand() * 6 + 6);

    ds = cosmo_synthetic_dataset('ntargets', ntargets, 'nchunks', 1);
    ds.samples(:, :) = randn(size(ds.samples));

    ncombi = ntargets * (ntargets - 1) / 2;
    vec1 = randn(1, ncombi);
    vec2 = randn(1, ncombi);

    dcm1 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', vec1, ...
                                         'regress_dsm', vec2);
    distance = cosmo_pdist(ds.samples, 'correlation');
    pcorr = partialcorr(distance', vec1', vec2');

    assertElementsAlmostEqual(dcm1.samples, pcorr);

    mat1 = cosmo_squareform(vec1);
    mat2 = cosmo_squareform(vec2);

    dcm2 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', mat1, ...
                                         'regress_dsm', mat2);
    assertElementsAlmostEqual(dcm1.samples, dcm2.samples);

    dcm1_s = cosmo_target_dsm_corr_measure(ds, 'target_dsm', vec1, ...
                                           'regress_dsm', vec2, ...
                                           'type', 'Spearman');
    pcorr_s = partialcorr(distance', vec1', vec2', 'type', 'Spearman');

    assertElementsAlmostEqual(dcm1_s.samples, pcorr_s);

    dcm2_s = cosmo_target_dsm_corr_measure(ds, 'target_dsm', mat1, ...
                                           'regress_dsm', mat2, ...
                                           'type', 'Spearman');
    assertElementsAlmostEqual(dcm1_s.samples, dcm2_s.samples);

function test_target_dsm_corr_measure_partial_cell_partialcorr
    if cosmo_skip_test_if_no_external('!partialcorr')
        return
    end
    ntargets = ceil(rand() * 6 + 6);
    ds = cosmo_synthetic_dataset('ntargets', ntargets, 'nchunks', 1);
    ds.samples(:, :) = randn(size(ds.samples));

    ncombi = ntargets * (ntargets - 1) / 2;
    vec1 = randn(1, ncombi);

    % set up regression elements
    n_regress = ceil(rand() * 5 + 3);
    regress_vec_cell = cell(1, n_regress);
    regress_mx = zeros(ncombi, n_regress);
    for k = 1:n_regress
        v = randn(1, ncombi);
        regress_vec_cell{k} = v;
        regress_mx(:, k) = v;
    end

    dcm1 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', vec1, ...
                                         'regress_dsm', regress_vec_cell);
    distance = cosmo_pdist(ds.samples, 'correlation');
    pcorr = partialcorr(distance', vec1', regress_mx);

    assertElementsAlmostEqual(dcm1.samples, pcorr);

    mat1 = cosmo_squareform(vec1);

    regress_mx_cell = cell(size(regress_vec_cell));
    for k = 1:n_regress
        regress_mx_cell{k} = squareform(regress_vec_cell{k});
    end

    dcm2 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', mat1, ...
                                         'regress_dsm', regress_mx_cell);
    assertElementsAlmostEqual(dcm1.samples, dcm2.samples);

function test_target_dsm_corr_measure_partial_regression
    ds = cosmo_synthetic_dataset('ntargets', 6, 'nchunks', 1);
    vec1 = [1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];
    vec2 = vec1(end:-1:1);

    dcm1 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', vec1, ...
                                         'regress_dsm', vec2);
    assertElementsAlmostEqual(dcm1.samples, 0.3082, 'absolute', 1e-4);
    mat1 = cosmo_squareform(vec1);
    mat2 = cosmo_squareform(vec2);

    dcm2 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', mat1, ...
                                         'regress_dsm', mat2);
    assertElementsAlmostEqual(dcm1.samples, dcm2.samples);

function test_target_dsm_corr_measure_partial_no_correlation
    t = [0 1 0 0
         0 0 -1 0
         0 0 0 0
         0 0 0 0];
    r = [0 0 1 0
         0 0 0 0
         0 0 0 0
         0 0 0 0];

    t = t + t';
    r = r + r';

    msk = triu(ones(size(t)), 1) > 0;
    c = cosmo_corr(t(msk), r(msk));
    assertElementsAlmostEqual(c, 0); % uncorrelated

    ds = cosmo_synthetic_dataset('nchunks', 1, 'ntargets', 4);

    t_base = cosmo_target_dsm_corr_measure(ds, 'target_dsm', t);
    t_regress = cosmo_target_dsm_corr_measure(ds, 'target_dsm', t, ...
                                              'regress_dsm', r);

    assertElementsAlmostEqual(t_base.samples, -0.6615, 'absolute', 1e-4);
    assertElementsAlmostEqual(t_regress.samples, -0.7410, 'absolute', 1e-4);

function test_target_dsm_corr_measure_non_pearson
    % test non-Pearson correlations
    ds = cosmo_synthetic_dataset('ntargets', 6, 'nchunks', 1);
    mat1 = [1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];

    dcm1 = cosmo_target_dsm_corr_measure(ds, 'target_dsm', mat1, ...
                                         'type', 'Spearman');

    assertElementsAlmostEqual(dcm1.samples, 0.2558, 'absolute', 1e-4);

function test_target_dsm_corr_measure_glm_dsm
    ds = cosmo_synthetic_dataset('ntargets', 6, 'nchunks', 1);
    ds_vec_row = cosmo_pdist(ds.samples, 'correlation');
    ds_vec = cosmo_normalize(ds_vec_row(:), 'zscore');

    mat1 = [1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];
    mat2 = mat1;
    mat2(1:10) = mat2(10:-1:1);

    nrows = numel(mat1);
    design_matrix = [cosmo_normalize([mat1', mat2'], 'zscore'), ones(nrows, 1)];

    betas = design_matrix \ ds_vec;

    dcm1 = cosmo_target_dsm_corr_measure(ds, 'glm_dsm', {mat1, mat2});

    assertElementsAlmostEqual(dcm1.samples, betas(1:2));
    assertElementsAlmostEqual(dcm1.samples, [0.3505; 0.3994], ...
                              'absolute', 1e-4);
    sa = cosmo_structjoin('labels', {'beta1'; 'beta2'}, ...
                          'metric', {'correlation'; 'correlation'});
    assertEqual(dcm1.sa, sa);

    mat2 = cosmo_squareform(mat2);
    dcm2 = cosmo_target_dsm_corr_measure(ds, 'glm_dsm', {1 + 3 * mat1, 2 * mat2});
    assertElementsAlmostEqual(dcm1.samples, dcm2.samples);
    assertEqual(dcm1.sa, dcm2.sa);

    dcm3 = cosmo_target_dsm_corr_measure(ds, 'glm_dsm', {mat2, mat1});
    assertElementsAlmostEqual(dcm1.samples([2 1], :), dcm3.samples);
    assertEqual(dcm1.sa, dcm3.sa);

function test_target_dsm_corr_measure_glm_dsm_matlab_correspondence
    if cosmo_skip_test_if_no_external('@stats')
        return
    end
    ds = cosmo_synthetic_dataset('ntargets', 6, 'nchunks', 1);
    ds.samples = randn(size(ds.samples));
    ds_vec_row = cosmo_pdist(ds.samples, 'correlation');
    ds_vec = cosmo_normalize(ds_vec_row(:), 'zscore');

    mat1 = rand(1, 15);
    mat2 = rand(1, 15);

    design_matrix = cosmo_normalize([mat1', mat2'], 'zscore');

    beta = regress(ds_vec, design_matrix);
    ds_beta = cosmo_target_dsm_corr_measure(ds, 'glm_dsm', {mat1, mat2});
    assertElementsAlmostEqual(beta, ds_beta.samples);

function test_target_dsm_random_data_with_cosmo_functions
    ntargets = ceil(rand() * 5 + 5);
    nfeatures = ceil(rand() * 20 + 30);

    ds = struct();
    ds.sa.targets = (1:ntargets)';
    ds.sa.chunks = ceil(rand() * 10 + 3);

    % assume working pdist (tested elsewhere)
    make_rand_dsm = @()cosmo_pdist(randn(ntargets, 2 * nfeatures));
    target_dsm = make_rand_dsm();
    glm_dsm = {make_rand_dsm(), make_rand_dsm(), make_rand_dsm()};

    for num_glms = 0:numel(glm_dsm)
        for center_data = [-1, 0, 1]
            for use_mask = [-1, 0, 1]
                ds.samples = randn(ntargets, nfeatures);
                samples = ds.samples;

                opt = struct();

                % optionally, center data
                if center_data > 0
                    opt.center_data = logical(center_data);

                    if opt.center_data
                        samples = bsxfun(@minus, samples, mean(samples, 1));
                    end
                end

                % compute pdist for samples
                c = cosmo_squareform(1 - cosmo_corr(samples'));
                n_pairs = numel(c);
                if use_mask > 0
                    while true
                        msk = rand(n_pairs, 1) > .5;

                        if sum(msk) > 3
                            break
                        end
                    end
                else
                    msk = true(n_pairs, 1);
                end

                if num_glms == 0
                    opt.target_dsm = target_dsm;
                    opt.target_dsm(~msk) = NaN;
                    expected_samples = cosmo_corr(c(msk)', ...
                                                  opt.target_dsm(msk)');
                else
                    opt.glm_dsm = glm_dsm(1:num_glms);
                    for k = 1:num_glms
                        opt.glm_dsm{k}(~msk) = NaN;
                    end

                    glm_mat = cat(1, opt.glm_dsm{:})';
                    glm_z = helper_quick_zscore(glm_mat(msk, :));
                    c_z = helper_quick_zscore(c(msk)');
                    expected_samples = glm_z \ c_z;
                end

                result = cosmo_target_dsm_corr_measure(ds, opt);
                assertElementsAlmostEqual(result.samples, expected_samples);
            end
        end
    end

function mat_z = helper_quick_zscore(mat)
    mat_c = bsxfun(@minus, mat, mean(mat, 1));
    mat_z = bsxfun(@rdivide, mat_c, std(mat_c, [], 1));

function test_target_dsm_corr_measure_mask_exceptions
    ntargets = 6;
    ds = cosmo_synthetic_dataset('ntargets', ntargets, 'nchunks', 1);

    npairs = ntargets * (ntargets - 1) / 2;
    for num_non_nan = 4:npairs
        nan_msk = true(npairs, 1);
        rp = randperm(npairs);
        nan_msk(rp(1:num_non_nan)) = false;

        for num_glms = -1:3
            opt = struct();

            if num_glms == -1
                opt.regress_dsm = randn(npairs, 1);
                opt.regress_dsm(nan_msk) = NaN;
            end

            if num_glms <= 0
                opt.target_dsm = randn(npairs, 1);
                opt.target_dsm(nan_msk) = NaN;
            else
                opt.glm_dsm = cell(num_glms, 1);
                for k = 1:num_glms
                    opt.glm_dsm{k} = randn(npairs, 1);
                    opt.glm_dsm{k}(nan_msk) = NaN;
                end
            end

            for set_inconsistent_non_nan_msk = [false, true]
                key_cell = intersect(fieldnames(opt), ...
                                     {'glm_dsm', 'regress_dsm'});

                if set_inconsistent_non_nan_msk && ...
                            (isempty(key_cell) || any(num_glms == [0, 1]))
                    % skip
                    continue
                end

                if set_inconsistent_non_nan_msk
                    assert(numel(key_cell) == 1);
                    key = key_cell{1};

                    % swap true and false value in one of the matrices

                    value = opt.(key);

                    value_is_cell = iscell(value);
                    if value_is_cell
                        glm_idx = num_glms;
                        value = value{glm_idx};
                    end

                    i = find(isnan(value), 1);
                    j = find(~isnan(value), 1);

                    value(i) = value(j);
                    value(j) = NaN;

                    if value_is_cell
                        opt.(key){glm_idx} = value;
                    else
                        opt.(key) = value;
                    end
                end

                func_handle = @()cosmo_target_dsm_corr_measure(ds, opt);

                expect_error = set_inconsistent_non_nan_msk;
                if expect_error
                    assertExceptionThrown(func_handle, '');
                else
                    % should be ok
                    func_handle();
                end
            end
        end
    end

    % test exceptions
function test_target_dsm_corr_measure_exceptions
    ds = cosmo_synthetic_dataset('ntargets', 6, 'nchunks', 1);
    mat1 = [1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];

    aet = @(varargin)assertExceptionThrown( ...
                                           @()cosmo_target_dsm_corr_measure(varargin{:}), '');
    aet(struct, mat1);
    aet(ds);
    aet(ds, 'target_dsm', [mat1 1]);
    aet(ds, 'target_dsm', eye(6));
    aet(ds, 'target_dsm', zeros(7));

    aet(ds, 'target_dsm', mat1, 'glm_dsm', {mat1});
    aet(ds, 'regress_dsm', mat1, 'glm_dsm', {mat1});
    aet(ds, 'target_dsm', mat1, 'glm_dsm', repmat({mat1}, 15, 1));
    aet(ds, 'regress_dsm', mat1, 'glm_dsm', repmat({mat1}, 15, 1));
    aet(ds, 'glm_dsm', struct());
    aet(ds, 'glm_dsm', {[mat1 1]});

    mat2_ds = struct();
    mat2_ds.samples = mat1;

    % requires numeric input
    mat2_ds_rep = repmat({mat2_ds}, 1, 15);
    mat2_ds_stacked = cat(1, mat2_ds_rep{:});
    aet(ds, 'target_dsm', mat2_ds_stacked);

    % illegal correlation type
    aet(ds, 'target_dsm', mat1, 'type', 'foo');
    aet(ds, 'target_dsm', mat1, 'type', 2);

    % Spearman or Kendall not allowed when using glm_dsm
    aet(ds, 'glm_dsm', mat1, 'type', 'Spearman');
    aet(ds, 'glm_dsm', mat1, 'type', 'Kendall');

    % Kendall not allowed with regress_dsm
    aet(ds, 'target_dsm', mat1, 'regress_dsm', {mat1}, 'type', 'Kendall');

function test_target_dsm_corr_measure_warnings_zero()
    ds = cosmo_synthetic_dataset('ntargets', 3, 'nchunks', 1);

    ds_zero = ds;
    ds_zero.samples(:) = 0;

    helper_target_dsm_corr_measure_with_warning(ds_zero, ...
                                                'target_dsm', [1 2 3]);

function test_target_dsm_corr_measure_warnings_nan()
    ds = cosmo_synthetic_dataset('ntargets', 3, 'nchunks', 1);

    ds_zero = ds;
    ds_zero.samples(1) = NaN;

    helper_target_dsm_corr_measure_with_warning(ds_zero, ...
                                                'target_dsm', [1 2 3]);

function test_target_dsm_corr_measure_warnings_constant()
    ds = cosmo_synthetic_dataset('ntargets', 3, 'nchunks', 1);

    ds_zero = ds;

    helper_target_dsm_corr_measure_with_warning(ds_zero, ...
                                                'target_dsm', [0 0 0]);

function result = helper_target_dsm_corr_measure_with_warning(varargin)
    % ensure to reset to original state when leaving this function
    warning_state = cosmo_warning();
    cleaner = onCleanup(@()cosmo_warning(warning_state));

    % clear all warnings
    empty_state = warning_state;
    empty_state.shown_warnings = {};
    cosmo_warning(empty_state);
    cosmo_warning('off');

    result = cosmo_target_dsm_corr_measure(varargin{:});
    w = cosmo_warning();
    assert(numel(w.shown_warnings) > 0);
    assert(iscellstr(w.shown_warnings));