test correlation measure

function test_suite=test_correlation_measure
% tests for cosmo_correlation_measure
%
% Returns:
%   test_suite      test suite output. It is not assigned explicitly in
%                   this function; presumably it is set by the
%                   initTestSuite script invoked below (TODO: confirm
%                   against the unit test framework in use).
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_correlation_measure_basis()
% Compares cosmo_correlation_measure output with correlations computed
% directly, for the default output, the 'correlation' output and the
% 'mean_by_fold' output; also checks that permuting the samples of a
% dataset does not change the result.
    full_ds=cosmo_synthetic_dataset('nchunks',3,'ntargets',4);
    keep_msk=full_ds.sa.chunks<=2;
    ds=cosmo_slice(full_ds,keep_msk);

    % offset targets and chunks so that they do not start at one
    ds.sa.targets=ds.sa.targets+20;
    ds.sa.chunks=ds.sa.chunks+10;

    first_half=ds.samples(ds.sa.chunks==11,:);
    second_half=ds.samples(ds.sa.chunks==12,:);

    % Fisher-transformed correlations between the two halves
    z=atanh(corr(first_half',second_half'));

    on_diag=eye(4)>0;
    expected_delta=mean(z(on_diag))-mean(z(~on_diag));

    res_default=cosmo_correlation_measure(ds);
    assertElementsAlmostEqual(expected_delta,res_default.samples,...
                                        'relative',1e-5);
    assertEqual(res_default.sa.labels,{'corr'});

    % from here on, reset the warning state and silence warnings; the
    % original state is restored when this function exits
    orig_warning_state=cosmo_warning();
    warning_cleaner=onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    res_corr=cosmo_correlation_measure(ds,'output','correlation');
    assertElementsAlmostEqual(z(:),res_corr.samples);
    assertEqual(kron((1:4)',ones(4,1)),res_corr.sa.half2);
    assertEqual(repmat((1:4)',4,1),res_corr.sa.half1);

    % spot check of a single element
    pos=7;
    row=res_corr.sa.half1(pos);
    col=res_corr.sa.half2(pos);
    assertElementsAlmostEqual(z(row,col),res_corr.samples(pos));

    target_values=20+(1:4)';
    assertEqual({'half1','half2'},res_corr.a.sdim.labels);
    assertEqual({target_values,target_values},res_corr.a.sdim.values);

    % every fold of 'mean_by_fold' must match the 'mean' output for a
    % dataset with the corresponding two-chunk assignment
    res_by_fold=cosmo_correlation_measure(full_ds,'output','mean_by_fold');

    for fold=1:3
        train_msk=false(12,1);
        train_msk((3-fold)*4+(1:4))=true;

        ds_fold=full_ds;
        ds_fold.sa.chunks(train_msk)=2;
        ds_fold.sa.chunks(~train_msk)=1;

        res_fold=cosmo_correlation_measure(ds_fold,'output','mean');
        assertElementsAlmostEqual(res_fold.samples, ...
                                        res_by_fold.samples(fold));
    end

    % test permutations
    ds_ordered=cosmo_synthetic_dataset('nchunks',2,'ntargets',10);
    shuffled_order=randperm(20);
    ds_shuffled=cosmo_slice(ds_ordered,shuffled_order);

    assertEqual(cosmo_correlation_measure(ds_ordered),...
                    cosmo_correlation_measure(ds_shuffled));

    opt=struct('output','correlation');
    assertEqual(cosmo_correlation_measure(ds_ordered,opt),...
                    cosmo_correlation_measure(ds_shuffled,opt));

function test_correlation_measure_single_target
% With all samples assigned to a single target and a scalar template of 1,
% the measure output is compared with the Fisher-transformed correlation
% between the sample averages of the two halves.
    target_counts=2:6;

    for k=1:numel(target_counts)
        ds=cosmo_synthetic_dataset('nchunks',2,...
                                    'ntargets',target_counts(k));
        ds.samples=randn(size(ds.samples));
        ds.sa.targets(:)=1;

        % split samples in two halves based on chunk parity
        chunk_parity=mod(ds.sa.chunks,2);
        half_idxs=cosmo_index_unique(chunk_parity);
        assert(numel(half_idxs)==2);

        avg_first=mean(ds.samples(half_idxs{1},:));
        avg_second=mean(ds.samples(half_idxs{2},:));
        expected_z=atanh(corr(avg_first',avg_second'));

        result=cosmo_correlation_measure(ds,'template',1);

        assertElementsAlmostEqual(expected_z, result.samples);
    end




function test_correlation_measure_regression()
% regression test using the non-Spearman parameter sets
    test_spearman=false;
    helper_test_correlation_measure_regression(test_spearman);

function test_correlation_measure_regression_spearman()
% regression test using the Spearman parameter set; nothing is tested when
% cosmo_skip_test_if_no_external('@stats') returns true
    skip_test=cosmo_skip_test_if_no_external('@stats');
    if ~skip_test
        helper_test_correlation_measure_regression(true);
    end

function helper_test_correlation_measure_regression(test_spearman)
% Runs cosmo_correlation_measure on a synthetic dataset for every
% parameter set returned by get_regression_test_params, and compares the
% output with the expected values stored in that parameter set.
%
% Inputs:
%   test_spearman       boolean passed on to get_regression_test_params
%                       to select which parameter sets are used
%
% Each parameter set is a cell with four elements:
%   1) cell with input arguments for cosmo_correlation_measure
%   2) expected samples (compared with absolute tolerance 5e-3)
%   3) cell with key-value pairs of expected sample attributes
%   4) cell with key-value pairs of expected .a.sdim fields, or empty
%      if the output is expected not to have an .a field

    % reset state and do not show warnings; the original warning state is
    % restored when this function exits
    orig_warning_state=cosmo_warning();
    warning_cleaner=onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    ds=cosmo_synthetic_dataset('ntargets',3,'nchunks',5,'sigma',.5);
    params=get_regression_test_params(test_spearman);

    n_params=numel(params);
    for k=1:n_params
        param=params{k};

        args=param{1};
        samples=param{2};
        sa=param{3};
        sdim=param{4};
        res=cosmo_correlation_measure(ds,args{:});

        % test samples
        assertElementsAlmostEqual(res.samples,samples','absolute',5e-3);

        % test sa: field names must match exactly, and each value must
        % be equal to the expected value in column vector form
        keys=fieldnames(res.sa);
        assertEqual(sort(keys(:)), sort(sa(1:2:end))');
        for j=1:2:numel(sa)
            key=sa{j};
            value=sa{j+1};
            assertEqual(res.sa.(key),value(:));
        end

        % test sdim
        if isempty(sdim)
            assertFalse(isfield(res,'a'));
        else
            keys=fieldnames(res.a.sdim);
            assertEqual(sort(keys(:)), sort(sdim(1:2:end))');

            % iterate over the key-value pairs in sdim (not those in
            % sa), so that every sdim field is verified irrespective of
            % how many sample attributes are expected
            for j=1:2:numel(sdim)
                key=sdim{j};
                value=sdim{j+1};
                sdim_value=res.a.sdim.(key);
                assertEqual(sdim_value(:),value(:));
            end
        end
    end

function params=get_regression_test_params(test_spearman)
% Returns a row cell with regression test cases. Each test case is a cell
% with four elements:
%   1) input arguments
%   2) samples
%   3) sample attributes
%   4) sdim
%
% If test_spearman is true a single Spearman test case is returned;
% otherwise six test cases that do not set 'corr_type' are returned.
    corr_sa={ 'labels', { 'corr' } };
    no_sdim=[];

    if test_spearman
        spearman_case={ { 'corr_type', 'Spearman' }, ...
                            -0.228, corr_sa, no_sdim };
        params={ spearman_case };
        return;
    end

    default_case={ {}, -0.24, corr_sa, no_sdim };

    template=[ -2, 2, 3; -1, 1, 2; 2, -4, -3 ];
    template_case={ { 'template', template }, ...
                        2.48, corr_sa, no_sdim };

    merge_func=@(x)sum(abs(x),1);
    merge_case={ { 'merge_func', merge_func }, ...
                        0.567, corr_sa, no_sdim };

    post_corr_func=@(x)x+1;
    post_corr_case={ { 'post_corr_func', post_corr_func }, ...
                        -0.204, corr_sa, no_sdim };

    fold_samples=[ -0.289, -0.274, -0.532, -0.112, -0.269, ...
                        -0.0535, -0.203, -0.3, -0.198, -0.173 ];
    fold_case={ { 'output', 'mean_by_fold' }, ...
                        fold_samples, ...
                        { 'partition', 1:10 }, ...
                        no_sdim };

    corr_samples=[ -0.649, 0.0675, 0.345, 0.126, 0.643, ...
                        0.413, 0.266, 0.399, 0.0933 ];
    half1=[ 1, 2, 3, 1, 2, 3, 1, 2, 3 ];
    half2=[ 1, 1, 1, 2, 2, 2, 3, 3, 3 ];
    target_values=[ 1; 2; 3 ];
    corr_case={ { 'output', 'correlation' }, ...
                        corr_samples, ...
                        { 'half1', half1, 'half2', half2 }, ...
                        { 'labels', { 'half1', 'half2' }, ...
                          'values', { target_values, target_values } } };

    params={ default_case, template_case, merge_case, ...
                        post_corr_case, fold_case, corr_case };



function test_correlation_measure_exceptions
% verifies that cosmo_correlation_measure throws an exception for
% unsupported arguments and for unsupported target assignments

    % helper: assert that calling the measure with the given arguments
    % throws an exception
    aet=@(varargin)assertExceptionThrown(@()...
                cosmo_correlation_measure(varargin{:}),'');

    % reset state and do not show warnings; the original warning state is
    % restored when this function exits
    orig_warning_state=cosmo_warning();
    warning_cleaner=onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    ds=cosmo_synthetic_dataset('nchunks',2);

    illegal_args={ {'template',eye(4)}, ...
                   {'output','foo'}, ...
                   {'output','one_minus_correlation'} };
    for k=1:numel(illegal_args)
        args=illegal_args{k};
        aet(ds,args{:});
    end

    % single target throws exception
    ds.sa.targets(:)=1;
    single_target_args={ {}, ...
                         {'template',2}, ...
                         {'template',eye(2)} };
    for k=1:numel(single_target_args)
        args=single_target_args{k};
        aet(ds,args{:});
    end

    % a different target value for only the first sample
    ds.sa.targets(1)=2;
    aet(ds);

function x=identity(x)
% returns the input unchanged: the output variable is the input variable
% and the function has no body
%
% NOTE(review): not called anywhere in the visible part of this file;
% confirm whether it is still needed

function test_correlation_measure_warning_shown_if_no_defaults()
% For each combination of 'post_corr_func' and 'output' options, checks
% whether cosmo_correlation_measure registered a warning (based on the
% shown_warnings field of the state returned by cosmo_warning).
    orig_warning_state=cosmo_warning();
    cleaner=onCleanup(@()cosmo_warning(orig_warning_state));

    % reset state and do not show warnings
    cosmo_warning('reset');
    cosmo_warning('off');

    % an empty value means that the option is not set
    post_corr_funcs={[],@atanh};
    output_types={[],'mean','raw','correlation'};

    n_funcs=numel(post_corr_funcs);
    n_outputs=numel(output_types);

    for func_idx=1:n_funcs
        post_corr_func=post_corr_funcs{func_idx};
        func_is_non_default=func_idx>2;

        for output_idx=1:n_outputs
            output_type=output_types{output_idx};
            output_is_non_default=output_idx>2;

            % a warning is expected when either option is non-default
            expect_warning=func_is_non_default || output_is_non_default;

            opt=struct();
            if ~isempty(post_corr_func)
                opt.post_corr_func=post_corr_func;
            end

            if ~isempty(output_type)
                opt.output=output_type;
            end

            % start each iteration without any registered warnings
            cosmo_warning('reset');
            cosmo_warning('off');

            ds=cosmo_synthetic_dataset('nchunks',2);
            cosmo_correlation_measure(ds,opt);

            warning_state=cosmo_warning();
            showed_warning=~isempty(warning_state.shown_warnings);

            assertEqual(expect_warning,showed_warning);
        end
    end

function test_correlation_measure_wrong_template_size()
% a 4x4 template with a dataset generated with 'ntargets' set to 2 must
% result in an exception
    ds=cosmo_synthetic_dataset('nchunks',2,'ntargets',2);

    template=[  1, -1,  0,  0; ...
               -1,  1,  0,  0; ...
                0,  0,  1, -1; ...
                0,  0, -1,  1 ];
    measure_args=struct('template',template);

    assertExceptionThrown(@()cosmo_correlation_measure(ds,measure_args));