test target dsm corr measure

function test_suite=test_target_dsm_corr_measure
% tests for cosmo_target_dsm_corr_measure
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    % Collect handles to all local (sub)functions of this file. The
    % variable must be named test_functions, because initTestSuite reads
    % it by that name.
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % NOTE(review): initTestSuite is presumably a script (MOxUnit/xUnit)
    % that assigns test_suite from the local test functions in this
    % workspace; confirm against the test framework version in use.
    initTestSuite;

function test_target_dsm_corr_measure_pearson
    % Pearson (default) correlation between the correlation-distance DSM
    % of the dataset and a target DSM, given in vector or square form.
    ds=cosmo_synthetic_dataset('ntargets',6,'nchunks',1);
    target_vec=[1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];
    target_sq=cosmo_squareform(target_vec);

    % target DSM as a vector
    res_vec=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_vec);
    assertElementsAlmostEqual(res_vec.samples,0.2507,'absolute',1e-4);
    assertEqual(res_vec.sa.labels,{'rho'});
    assertEqual(res_vec.sa.metric,{'correlation'});
    assertEqual(res_vec.sa.type,{'Pearson'});

    % must agree with a direct computation
    ds_dsm=cosmo_pdist(ds.samples,'correlation');
    assertElementsAlmostEqual(cosmo_corr(ds_dsm',target_vec'),...
                                        res_vec.samples);

    % target DSM as a square matrix gives the same output
    res_sq=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_sq);
    assertElementsAlmostEqual(res_vec.samples,res_sq.samples);
    res_sq.samples=res_vec.samples;
    assertEqual(res_vec,res_sq);

    % euclidean distance for the dataset DSM
    res_eucl=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_sq,...
                                                'metric','euclidean');
    assertElementsAlmostEqual(res_eucl.samples,0.3037,'absolute',1e-4);


function test_target_dsm_corr_measure_partial
    % partial correlation: target DSM vs data DSM, with a second DSM
    % regressed out
    ds=cosmo_synthetic_dataset('ntargets',6,'nchunks',1);
    target=[1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];
    regressor=fliplr(target);

    res=cosmo_target_dsm_corr_measure(ds,'target_dsm',target,...
                                          'regress_dsm',regressor);
    assertElementsAlmostEqual(res.samples,0.3082,'absolute',1e-4);


function test_target_dsm_corr_measure_partial_vector_partialcorr
    % compare partial correlation against partialcorr, for Pearson and
    % Spearman, with DSMs given in vector and in square-matrix form
    if cosmo_skip_test_if_no_external('!partialcorr')
        return;
    end

    nt=ceil(rand()*6+6);
    ds=cosmo_synthetic_dataset('ntargets',nt,'nchunks',1);
    ds.samples(:,:)=randn(size(ds.samples));

    npairs=nt*(nt-1)/2;
    target_vec=randn(1,npairs);
    regress_vec=randn(1,npairs);

    target_sq=cosmo_squareform(target_vec);
    regress_sq=cosmo_squareform(regress_vec);
    ds_dsm_col=cosmo_pdist(ds.samples,'correlation')';

    % Pearson (default type)
    res=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_vec,...
                                        'regress_dsm',regress_vec);
    assertElementsAlmostEqual(res.samples,...
                    partialcorr(ds_dsm_col,target_vec',regress_vec'));

    res_sq=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_sq,...
                                        'regress_dsm',regress_sq);
    assertElementsAlmostEqual(res.samples,res_sq.samples);

    % Spearman
    res_sp=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_vec,...
                                        'regress_dsm',regress_vec,...
                                        'type','Spearman');
    assertElementsAlmostEqual(res_sp.samples,...
                    partialcorr(ds_dsm_col,target_vec',regress_vec',...
                                'type','Spearman'));

    res_sp_sq=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_sq,...
                                        'regress_dsm',regress_sq,...
                                        'type','Spearman');
    assertElementsAlmostEqual(res_sp.samples,res_sp_sq.samples);


function test_target_dsm_corr_measure_partial_cell_partialcorr
    % compare partial correlation against partialcorr when several
    % regressor DSMs are given as a cell, in vector and square form
    if cosmo_skip_test_if_no_external('!partialcorr')
        return;
    end
    ntargets=ceil(rand()*6+6);
    ds=cosmo_synthetic_dataset('ntargets',ntargets,'nchunks',1);
    ds.samples(:,:)=randn(size(ds.samples));

    ncombi=ntargets*(ntargets-1)/2;
    vec1=randn(1,ncombi);

    % set up regression elements: one row vector per cell element, and
    % the same values as columns of regress_mx (for partialcorr)
    n_regress=ceil(rand()*5+3);
    regress_vec_cell=cell(1,n_regress);
    regress_mx=zeros(ncombi,n_regress);
    for k=1:n_regress
        v=randn(1,ncombi);
        regress_vec_cell{k}=v;
        regress_mx(:,k)=v;
    end

    dcm1=cosmo_target_dsm_corr_measure(ds,'target_dsm',vec1,...
                                           'regress_dsm',regress_vec_cell);
    distance=cosmo_pdist(ds.samples,'correlation');
    pcorr=partialcorr(distance',vec1',regress_mx);

    assertElementsAlmostEqual(dcm1.samples,pcorr);

    % Same DSMs in square-matrix form. Use cosmo_squareform, as for the
    % target DSM, rather than the Statistics Toolbox squareform; the
    % skip check above only verifies that partialcorr is available.
    mat1=cosmo_squareform(vec1);
    regress_mx_cell=cellfun(@cosmo_squareform,regress_vec_cell,...
                                    'UniformOutput',false);

    dcm2=cosmo_target_dsm_corr_measure(ds,'target_dsm',mat1,...
                                            'regress_dsm',regress_mx_cell);
    assertElementsAlmostEqual(dcm1.samples,dcm2.samples);

function test_target_dsm_corr_measure_partial_regression
    % partial correlation with a single regressor DSM gives the same
    % result for vector and square-matrix input
    ds=cosmo_synthetic_dataset('ntargets',6,'nchunks',1);
    target_vec=[1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];
    regress_vec=fliplr(target_vec);

    res_vec=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_vec,...
                                            'regress_dsm',regress_vec);
    assertElementsAlmostEqual(res_vec.samples,0.3082,'absolute',1e-4);

    target_sq=cosmo_squareform(target_vec);
    regress_sq=cosmo_squareform(regress_vec);
    res_sq=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_sq,...
                                            'regress_dsm',regress_sq);
    assertElementsAlmostEqual(res_vec.samples,res_sq.samples);


function test_target_dsm_corr_measure_partial_no_correlation
    % regressing out a DSM that is uncorrelated with the target DSM still
    % changes the result (expected values below)
    target=zeros(4);
    target(1,2)=1;
    target(2,3)=-1;

    regressor=zeros(4);
    regressor(1,3)=1;

    % symmetrize
    target=target+target';
    regressor=regressor+regressor';

    % the two DSMs are uncorrelated over the upper triangle
    above_diag=triu(ones(4),1)>0;
    rho=cosmo_corr(target(above_diag),regressor(above_diag));
    assertElementsAlmostEqual(rho,0);

    ds=cosmo_synthetic_dataset('nchunks',1,'ntargets',4);

    res_plain=cosmo_target_dsm_corr_measure(ds,'target_dsm',target);
    res_partial=cosmo_target_dsm_corr_measure(ds,'target_dsm',target,...
                                    'regress_dsm',regressor);

    assertElementsAlmostEqual(res_plain.samples,-0.6615,'absolute',1e-4);
    assertElementsAlmostEqual(res_partial.samples,-0.7410,'absolute',1e-4);


function test_target_dsm_corr_measure_non_pearson
    % Spearman instead of the default Pearson correlation
    ds=cosmo_synthetic_dataset('ntargets',6,'nchunks',1);
    target=[1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];

    res=cosmo_target_dsm_corr_measure(ds,'target_dsm',target,...
                                            'type','Spearman');
    assertElementsAlmostEqual(res.samples,0.2558,'absolute',1e-4);


function test_target_dsm_corr_measure_glm_dsm
    % glm_dsm: betas from regressing the z-scored data DSM on z-scored
    % predictor DSMs plus an intercept
    ds=cosmo_synthetic_dataset('ntargets',6,'nchunks',1);

    pred1=[1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];
    pred2=pred1;
    pred2(1:10)=pred1(10:-1:1);

    data_dsm=cosmo_pdist(ds.samples,'correlation');
    y=cosmo_normalize(data_dsm(:),'zscore');

    npairs=numel(pred1);
    X=[cosmo_normalize([pred1',pred2'],'zscore'),ones(npairs,1)];
    expected_betas=X \ y;

    res=cosmo_target_dsm_corr_measure(ds,'glm_dsm',{pred1,pred2});
    assertElementsAlmostEqual(res.samples,expected_betas(1:2));
    assertElementsAlmostEqual(res.samples,[0.3505;0.3994],...
                                    'absolute',1e-4);
    expected_sa=cosmo_structjoin('labels',{'beta1';'beta2'},...
                            'metric',{'correlation';'correlation'});
    assertEqual(res.sa,expected_sa);

    % affine rescaling of predictors, and square vs vector form,
    % do not change the betas
    pred2_sq=cosmo_squareform(pred2);
    res_scaled=cosmo_target_dsm_corr_measure(ds,'glm_dsm',...
                                            {1+3*pred1,2*pred2_sq});
    assertElementsAlmostEqual(res.samples,res_scaled.samples);
    assertEqual(res.sa,res_scaled.sa);

    % swapping predictors swaps the betas; labels stay beta1, beta2
    res_swapped=cosmo_target_dsm_corr_measure(ds,'glm_dsm',...
                                            {pred2_sq,pred1});
    assertElementsAlmostEqual(res.samples([2 1],:), res_swapped.samples);
    assertEqual(res.sa,res_swapped.sa);


function test_target_dsm_corr_measure_glm_dsm_matlab_correspondence
    % glm_dsm betas must match Matlab's regress on z-scored DSMs
    if cosmo_skip_test_if_no_external('@stats')
        return;
    end
    ds=cosmo_synthetic_dataset('ntargets',6,'nchunks',1);
    ds.samples=randn(size(ds.samples));

    data_dsm=cosmo_pdist(ds.samples,'correlation');
    y=cosmo_normalize(data_dsm(:),'zscore');

    preds={rand(1,15),rand(1,15)};
    X=cosmo_normalize(cat(1,preds{:})','zscore');

    expected=regress(y,X);
    res=cosmo_target_dsm_corr_measure(ds,'glm_dsm',preds);
    assertElementsAlmostEqual(expected, res.samples);

function test_target_dsm_random_data_with_cosmo_functions
    % Compare the measure with an independent computation (cosmo_corr,
    % or least squares on z-scored data) on random data. Covers
    % target_dsm and glm_dsm with 1 to 3 predictors; center_data unset,
    % false or true; and NaN entries in the DSM(s) that mask out a
    % random subset of pairs.
    ntargets=ceil(rand()*5+5);
    nfeatures=ceil(rand()*20+30);

    ds=struct();
    ds.sa.targets=(1:ntargets)';
    % one chunk value, stored with one row per sample (like targets)
    ds.sa.chunks=repmat(ceil(rand()*10+3),ntargets,1);

    % assume working pdist (tested elsewhere)
    make_rand_dsm=@()cosmo_pdist(randn(ntargets,2*nfeatures));
    target_dsm=make_rand_dsm();
    glm_dsm={make_rand_dsm(),make_rand_dsm(),make_rand_dsm()};

    for num_glms=0:numel(glm_dsm)
        % center_data: -1 leaves the option unset (default), 0 sets it
        % explicitly to false, 1 sets it explicitly to true
        for center_data=[-1,0,1]
            % use_mask>0: NaN out a random subset of pairs; otherwise
            % all pairs are used
            for use_mask=[-1,0,1]
                ds.samples=randn(ntargets,nfeatures);
                samples=ds.samples;

                opt=struct();

                % optionally, set center_data (also explicitly false)
                if center_data>=0
                    opt.center_data=logical(center_data);

                    if opt.center_data
                        samples=bsxfun(@minus,samples,mean(samples,1));
                    end
                end

                % compute pdist for samples
                c=cosmo_squareform(1-cosmo_corr(samples'));
                n_pairs=numel(c);
                if use_mask>0
                    % keep enough pairs for a meaningful correlation
                    while true
                        msk=rand(n_pairs,1)>.5;

                        if sum(msk)>3
                            break;
                        end
                    end
                else
                    msk=true(n_pairs,1);
                end

                if num_glms==0
                    opt.target_dsm=target_dsm;
                    opt.target_dsm(~msk)=NaN;
                    expected_samples=cosmo_corr(c(msk)',...
                                            opt.target_dsm(msk)');
                else
                    opt.glm_dsm=glm_dsm(1:num_glms);
                    for k=1:num_glms
                        opt.glm_dsm{k}(~msk)=NaN;
                    end

                    % pairs x predictors; z-scoring removes the need for
                    % an intercept column
                    glm_mat=cat(1,opt.glm_dsm{:})';
                    glm_z=helper_quick_zscore(glm_mat(msk,:));
                    c_z=helper_quick_zscore(c(msk)');
                    expected_samples=glm_z \ c_z;
                end

                result=cosmo_target_dsm_corr_measure(ds,opt);
                assertElementsAlmostEqual(result.samples,expected_samples);
            end
        end
    end



function mat_z=helper_quick_zscore(mat)
    % z-score each column of mat: subtract the column mean, then divide
    % by the sample standard deviation of the centered column
    col_means=mean(mat,1);
    centered=bsxfun(@minus,mat,col_means);
    col_stds=std(centered,[],1);
    mat_z=bsxfun(@rdivide,centered,col_stds);


function test_target_dsm_corr_measure_mask_exceptions
    % NaN entries in the DSMs act as a mask. The call should succeed when
    % all DSMs share the same NaN positions, and raise an error when one
    % DSM (regress_dsm, or the last glm_dsm element) has NaNs at
    % different positions than the others.
    ntargets=6;
    ds=cosmo_synthetic_dataset('ntargets',ntargets,'nchunks',1);

    npairs=ntargets*(ntargets-1)/2;
    for num_non_nan=4:npairs
        % nan_msk is true where DSM values are set to NaN; exactly
        % num_non_nan randomly chosen pairs stay non-NaN
        nan_msk=true(npairs,1);
        rp=randperm(npairs);
        nan_msk(rp(1:num_non_nan))=false;

        % num_glms: -1 = target_dsm + regress_dsm; 0 = target_dsm only;
        % 1..3 = glm_dsm with that many predictors
        for num_glms=-1:3
            opt=struct();

            if num_glms==-1
                opt.regress_dsm=randn(npairs,1);
                opt.regress_dsm(nan_msk)=NaN;
            end

            if num_glms<=0
                opt.target_dsm=randn(npairs,1);
                opt.target_dsm(nan_msk)=NaN;
            else
                opt.glm_dsm=cell(num_glms,1);
                for k=1:num_glms
                    opt.glm_dsm{k}=randn(npairs,1);
                    opt.glm_dsm{k}(nan_msk)=NaN;
                end
            end

            % first iteration (false): consistent masks, must succeed;
            % second iteration (true): opt is modified in place so one
            % DSM's mask differs from the others, must raise an error
            for set_inconsistent_non_nan_msk=[false,true]
                key_cell=intersect(fieldnames(opt),...
                                    {'glm_dsm','regress_dsm'});

                % an inconsistency needs at least two DSMs: skip for
                % target_dsm only (num_glms==0) and a single glm_dsm
                % (num_glms==1)
                if set_inconsistent_non_nan_msk && ...
                            (isempty(key_cell) || any(num_glms==[0,1]))
                    % skip
                    continue;
                end

                if set_inconsistent_non_nan_msk
                    assert(numel(key_cell)==1);
                    key=key_cell{1};

                    % swap true and false value in one of the matrices

                    value=opt.(key);

                    % for glm_dsm, modify the last predictor only
                    value_is_cell=iscell(value);
                    if value_is_cell
                        glm_idx=num_glms;
                        value=value{glm_idx};
                    end

                    % move one NaN to the first non-NaN position. If there
                    % are no NaNs (num_non_nan==npairs), i is empty and
                    % this just introduces a single NaN.
                    i=find(isnan(value),1);
                    j=find(~isnan(value),1);

                    value(i)=value(j);
                    value(j)=NaN;

                    if value_is_cell
                        opt.(key){glm_idx}=value;
                    else
                        opt.(key)=value;
                    end
                end

                func_handle=@()cosmo_target_dsm_corr_measure(ds,opt);

                expect_error=set_inconsistent_non_nan_msk;
                if expect_error
                    assertExceptionThrown(func_handle,'');
                else
                    % should be ok
                    func_handle();
                end
            end
        end
    end



% test exceptions
function test_target_dsm_corr_measure_exceptions
    % invalid inputs and option combinations must raise an error
    ds=cosmo_synthetic_dataset('ntargets',6,'nchunks',1);
    mat1=[1 2 3 4 2 3 2 1 2 3 1 2 3 2 2];
    mat1_rev=mat1(end:-1:1);

    aet=@(varargin)assertExceptionThrown(...
                @()cosmo_target_dsm_corr_measure(varargin{:}),'');
    aet(struct,mat1);
    aet(ds);
    aet(ds,'target_dsm',[mat1 1]);
    aet(ds,'target_dsm',eye(6));
    aet(ds,'target_dsm',zeros(7));

    aet(ds,'target_dsm',mat1,'glm_dsm',{mat1});
    aet(ds,'regress_dsm',mat1,'glm_dsm',{mat1});
    aet(ds,'target_dsm',mat1,'glm_dsm',repmat({mat1},15,1));
    aet(ds,'regress_dsm',mat1,'glm_dsm',repmat({mat1},15,1));
    aet(ds,'glm_dsm',struct());
    aet(ds,'glm_dsm',{[mat1 1]});

    mat2_ds=struct();
    mat2_ds.samples=mat1;

    % requires numeric input
    mat2_ds_rep=repmat({mat2_ds},1,15);
    mat2_ds_stacked=cat(1,mat2_ds_rep{:});
    aet(ds,'target_dsm',mat2_ds_stacked);

    % illegal correlation type
    aet(ds,'target_dsm',mat1,'type','foo');
    aet(ds,'target_dsm',mat1,'type',2);

    % Spearman or Kendall not allowed when using glm_dsm. glm_dsm is given
    % as a valid cell, so that the correlation type is the only reason
    % for the error.
    aet(ds,'glm_dsm',{mat1},'type','Spearman');
    aet(ds,'glm_dsm',{mat1},'type','Kendall');

    % Kendall not allowed with regress_dsm. The regressor differs from the
    % target DSM, so the error cannot come from a degenerate regression.
    aet(ds,'target_dsm',mat1,'regress_dsm',{mat1_rev},'type','Kendall');


function test_target_dsm_corr_measure_warnings_zero()
    % an all-zero dataset must produce at least one warning
    ds_zero=cosmo_synthetic_dataset('ntargets',3,'nchunks',1);
    ds_zero.samples(:)=0;

    target=[1 2 3];
    helper_target_dsm_corr_measure_with_warning(ds_zero,...
                                        'target_dsm',target);

function test_target_dsm_corr_measure_warnings_nan()
    % a dataset containing a NaN must produce at least one warning
    ds_nan=cosmo_synthetic_dataset('ntargets',3,'nchunks',1);
    ds_nan.samples(1)=NaN;

    target=[1 2 3];
    helper_target_dsm_corr_measure_with_warning(ds_nan,...
                                        'target_dsm',target);

function test_target_dsm_corr_measure_warnings_constant()
    % a constant target DSM must produce at least one warning
    ds=cosmo_synthetic_dataset('ntargets',3,'nchunks',1);

    constant_target=[0 0 0];
    helper_target_dsm_corr_measure_with_warning(ds,...
                                        'target_dsm',constant_target);

function result=helper_target_dsm_corr_measure_with_warning(varargin)
    % Run cosmo_target_dsm_corr_measure(varargin{:}) and assert that at
    % least one warning was recorded by cosmo_warning. The original
    % cosmo_warning state is restored when this function exits.
    orig_state=cosmo_warning();
    restore_on_exit=onCleanup(@()cosmo_warning(orig_state));

    % start without any previously shown warnings, then switch warnings
    % 'off' (presumably suppresses printing only; shown_warnings is still
    % expected to be filled, as asserted below)
    fresh_state=orig_state;
    fresh_state.shown_warnings={};
    cosmo_warning(fresh_state);
    cosmo_warning('off');

    result=cosmo_target_dsm_corr_measure(varargin{:});

    final_state=cosmo_warning();
    shown=final_state.shown_warnings;
    assert(numel(shown)>0)
    assert(iscellstr(shown));