cosmo correlation measure

function ds_sa=cosmo_correlation_measure(ds, varargin)
% Computes a split-half correlation measure
%
% d=cosmo_correlation_measure(ds[, args])
%
% Inputs:
%  ds             dataset structure with fields .samples, .sa.targets and
%                 .sa.chunks
%  args           optional struct with the following optional fields:
%    .partitions  struct with fields .train_indices and .test_indices.
%                 Both should be a Px1 cell for P partitions. If omitted,
%                 it is set to cosmo_nchoosek_partitioner(ds,'half').
%    .template    QxQ matrix for Q classes in each chunk. This matrix
%                 weights the correlations across the two halves.
%                 If ds.sa.targets has only one unique value, it must be
%                 set to the scalar value 1; otherwise it should
%                 have a sum (and thus a mean) of zero. If omitted, it has
%                 positive values of (1/Q) on the diagonal and
%                 (-1/(Q*(Q-1))) off the diagonal. Elements of the
%                 template that are not finite (e.g. NaN) are ignored when
%                 weighting the correlations.
%                 (Note: this can be used to test for representational
%                 similarity matching)
%    .merge_func  A function handle used to merge data from matching
%                 targets in the same chunk. Default is @(x) mean(x,1),
%                 meaning that values are averaged over the same samples.
%                 It is assumed that isequal(args.merge_func(y),y) if y
%                 is a row vector.
%    .corr_type   Type of correlation: 'Pearson','Spearman','Kendall'.
%                 The default is 'Pearson'.
%    .post_corr_func  Operation performed after correlation. (default:
%                     @atanh)
%    .output      'mean' (default): correlations weighted by template
%                 'raw' or 'correlation': correlations between all classes
%                 'mean_by_fold': provide weighted correlations for each
%                                 fold in the partitions.
%                 (The former option 'one_minus_correlation' has been
%                 removed and now raises an error.)
%    .check_partitions  if true (default), the partitions are validated
%                       using cosmo_check_partitions (this check is
%                       skipped when default partitions cached from a
%                       previous call are reused).
%
%
% Output:
%    ds_sa        Struct with fields:
%      .samples   If output=='mean': scalar indicating how well the
%                 template matrix correlates with the correlation matrix
%                 from the two halves (averaged over partitions). By
%                 default:
%                 - this value is based on Fisher-transformed correlation
%                   values, not raw correlation values
%                 - this is the average of the (Fisher-transformed)
%                   on-diagonal minus the average of the
%                   (Fisher-transformed) off-diagonal elements of the
%                   correlation matrix based on the two halves of the data.
%                 If output=='mean_by_fold': Px1 vector with the
%                 template-weighted correlation for each of the P
%                 partitions.
%                 If output=='raw' or 'correlation': (N^2)x1 vector with
%                 the correlations between all pairs of classes in the two
%                 halves (averaged over partitions).
%      .sa        Struct with field(s):
%        .labels     {'corr'}, if output=='mean'
%        .partition  (1:P)', if output=='mean_by_fold'
%        .half1   } if output=='raw' or 'correlation': (N^2)x1 vectors
%        .half2   } with indices of data from two halves, with N the
%                 } number of unique targets.
%      .a         if output=='raw' or 'correlation': field .sdim with
%                 .labels {'half1','half2'} and .values holding the unique
%                 targets for each half.
%
% Example:
%     ds=cosmo_synthetic_dataset();
%     %
%     % compute on-minus-off diagonal correlations
%     c=cosmo_correlation_measure(ds);
%     cosmo_disp(c)
%     %|| .samples
%     %||   [ 1.23 ]
%     %|| .sa
%     %||   .labels
%     %||   { 'corr' }
%     %
%     % Spearman correlation requires the Matlab statistics toolbox
%     cosmo_skip_test_if_no_external('@stats');
%     %
%     c=cosmo_correlation_measure(ds,'corr_type','Spearman');
%     cosmo_disp(c)
%     %|| .samples
%     %||   [ 1.28 ]
%     %|| .sa
%     %||   .labels
%     %||   { 'corr' }
%
%     ds=cosmo_synthetic_dataset();
%     % get raw correlations without fisher transformation
%     c_raw=cosmo_correlation_measure(ds,'output','correlation',...
%                                       'post_corr_func',[]);
%     cosmo_disp(c_raw)
%     %|| .samples
%     %||   [  0.363
%     %||     -0.404
%     %||     -0.447
%     %||      0.606 ]
%     %|| .sa
%     %||   .half1
%     %||     [ 1
%     %||       2
%     %||       1
%     %||       2 ]
%     %||   .half2
%     %||     [ 1
%     %||       1
%     %||       2
%     %||       2 ]
%     %|| .a
%     %||   .sdim
%     %||     .labels
%     %||       { 'half1'  'half2' }
%     %||     .values
%     %||       { [ 1    [ 1
%     %||           2 ]    2 ] }
%     %
%     % convert to matrix form (N x N x P, with N the number of classes and
%     % P=1)
%     matrices=cosmo_unflatten(c_raw,1);
%     cosmo_disp(matrices)
%     %|| [  0.363    -0.447
%     %||   -0.404     0.606 ]
%
%     % compute for each fold separately, using a custom take-one-chunk
%     % out partitioning scheme and without Fisher transformation of
%     % the correlations.
%     % Note that c.sa.partition in the output contains the index of
%     % the partition (fold) for each row of c.samples
%     ds=cosmo_synthetic_dataset('type','fmri','nchunks',4);
%     partitions=cosmo_nfold_partitioner(ds);
%     c=cosmo_correlation_measure(ds,'output','mean_by_fold',...
%                                    'partitions',partitions,...
%                                    'post_corr_func',[]);
%     cosmo_disp(c.samples);
%     %|| [  1.32
%     %||   0.512
%     %||    1.05
%     %||    1.23 ]
%     cosmo_disp(c.sa);
%     %|| .partition
%     %||   [ 1
%     %||     2
%     %||     3
%     %||     4 ]
%
%     % minimal searchlight example
%     ds=cosmo_synthetic_dataset('type','fmri');
%     % use searchlight with radius 1 voxel (radius=3 is more typical)
%     nbrhood=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
%     % run searchlight
%     res=cosmo_searchlight(ds,nbrhood,@cosmo_correlation_measure,...
%                               'progress',false);
%     cosmo_disp(res.samples)
%     %|| [ 1.87      1.25      1.51      1.68      1.71     0.879 ]
%     cosmo_disp(res.sa)
%     %|| .labels
%     %||   { 'corr' }
%
% Notes:
%   - by default the post_corr_func is set to @atanh. This is equivalent to
%     a Fisher transformation from r (correlation) to z (Fisher z-value).
%     The underlying math is z=atanh(r)=.5*log((1+r)./(1-r)).
%     The rationale is to make data more normally distributed under the
%     null hypothesis.
%     Fisher-transformed correlations can be transformed back to
%     their original correlation values using 'tanh', which is the inverse
%     of 'atanh'.
%   - To disable the (by default used) Fisher-transformation, set the
%     'post_corr_func' option to [].
%   - if multiple samples are present with the same chunk and target, they
%     are averaged *prior* to computing the correlations.
%   - if multiple partitions are present, then the correlations are
%     computed separately for each partition, and then averaged (unless
%     the 'output' option is set to 'mean_by_fold').
%   - When the 'partitions' option is not specified and more than two
%     chunks are present in the input, partitions consist of all possible
%     half splits for which the number of unique chunks in the train and
%     test set differ by 1 at most.
%     For illustration, up to 6 chunks, the
%     partitions are:
%       - 2 chunks   -  partition #1
%         chunks first  half: {    1
%         chunks second half:  {   2
%
%       - 3 chunks   -  partition #1  #2  #3
%         chunks first  half: {    3   2   1
%                              {   1   1   2
%         chunks second half:  {   2   3   3
%
%       - 4 chunks   -  partition #1  #2  #3
%                             {    3   2   2
%         chunks first  half: {    4   4   3
%                              {   1   1   1
%         chunks second half:  {   2   3   4
%
%       - 5 chunks   -  partition #1  #2  #3  #4  #5  #6  #7  #8  #9  #10
%                             {    4   3   3   2   2   2   1   1   1   1
%         chunks first  half: {    5   5   4   5   4   3   5   4   3   2
%                              {   1   1   1   1   1   1   2   2   2   3
%         chunks second half:  {   2   2   2   3   3   4   3   3   4   4
%                              {   3   4   5   4   5   5   4   5   5   5
%
%       - 6 chunks   -  partition #1  #2  #3  #4  #5  #6  #7  #8  #9  #10
%                             {    4   3   3   3   2   2   2   2   2   2
%         chunks first  half: {    5   5   4   4   5   4   4   3   3   3
%                             {    6   6   6   5   6   6   5   6   5   4
%                              {   1   1   1   1   1   1   1   1   1   1
%         chunks second half:  {   2   2   2   2   3   3   3   4   4   5
%                              {   3   4   5   6   4   5   6   5   6   6
%     Thus, with an increasing number of chunks, the number of partitions
%     (and thus the time required to run this function) increases
%     combinatorially (faster than any polynomial). To use simpler
%     partition schemes (e.g. odd-even, as provided by
%     cosmo_oddeven_partitioner), specify the 'partitions' argument.
%
% References
%   - Haxby, J. V. et al (2001). Distributed and overlapping
%     representations of faces and objects in ventral temporal cortex.
%     Science 293, 2425-2430
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

persistent cached_partitions;
persistent cached_chunks;
persistent cached_params;
persistent cached_varargin;

% optimize parameters parsing: if same arguments were used in
% a previous call, do not use cosmo_structjoin (which is relatively
% expensive)
if ~isempty(cached_params) && isequal(cached_varargin, varargin)
    params=cached_params;
else
    defaults=struct();
    defaults.partitions=[];
    defaults.template=[];
    defaults.merge_func=[];
    defaults.corr_type='Pearson';
    defaults.post_corr_func=@atanh;
    defaults.output='mean';
    defaults.check_partitions=true;

    params=cosmo_structjoin(defaults, varargin);

    show_warning_if_no_defaults(defaults,params);

    cached_params=params;
    cached_varargin=varargin;
end

partitions=params.partitions;
template=params.template;
merge_func=params.merge_func;
post_corr_func=params.post_corr_func;
check_partitions=params.check_partitions;

chunks=ds.sa.chunks;

if isempty(partitions)
    % the cache is keyed on .sa.chunks only, i.e. it assumes that the
    % default partitions depend only on the chunks
    if ~isempty(cached_chunks) && isequal(cached_chunks, chunks)
        partitions=cached_partitions;

        % assume that partitions were already checked
        check_partitions=false;
    else
        partitions=cosmo_nchoosek_partitioner(ds,'half');
        cached_chunks=chunks;
        cached_partitions=partitions;
    end
end

if check_partitions
    cosmo_check_partitions(partitions, ds, params);
end

targets=ds.sa.targets;

% class_ids(i) is the index of targets(i) in the sorted unique classes
[classes,unused,class_ids]=fast_unique(targets);
nclasses=numel(classes);

switch nclasses
    case 0
        error('No classes found - this is not supported');
    case 1
        if ~isequal(template,1)
            error([ 'Only one unique value for .sa.targets was found; '...
                    'this is only allowed when the ''template'' '...
                    'parameter is explicitly set to 1.\n'...
                    'Note that this option does not compute the '...
                    'correlation *difference* between matching and '...
                    'non-matching values in .sa.targets; instead it '...
                    'computes the direct correlation '...
                    'between two halves of the data. A typical use '...
                    'case is when .samples already contains a '...
                    'difference score  comparing two different '...
                    'conditions']);
        end
    otherwise
        if isempty(template)
            % 1/Q on the diagonal and -1/(Q*(Q-1)) off the diagonal,
            % so that all weights sum to zero
            template=(eye(nclasses)-1/nclasses)/(nclasses-1);
        else
            % check the size first, so that a template with the wrong
            % size gives an error about its size
            expected_size=[nclasses,nclasses];
            if ~isequal(size(template),expected_size)
                error('Template must be of size %dx%d',expected_size);
            end

            % note: if the template contains NaN, its sum is NaN and the
            % comparison below is false, so the sum is not checked then
            max_tolerance=1e-8;
            if abs(sum(template(:)))>max_tolerance
                error('Template matrix does not have a sum of zero');
            end
        end
end

% only elements with a finite template value are used for weighting
template_msk=isfinite(template);

npartitions=numel(partitions.train_indices);

% aggregated correlations (according to params.output) for each partition
pdata=cell(npartitions,1);

for k=1:npartitions
    train_indices=partitions.train_indices{k};
    test_indices=partitions.test_indices{k};

    % get data in each half (one row per class)
    half1=get_data(ds, train_indices, class_ids, merge_func);
    half2=get_data(ds, test_indices, class_ids, merge_func);

    % compute raw correlations between the classes in the two halves
    raw_c=cosmo_corr(half1', half2', params.corr_type);

    % apply post-processing (usually Fisher-transform, i.e. atanh)
    c=apply_post_corr_func(post_corr_func,raw_c);

    % aggregate results
    pdata{k}=aggregate_correlations(c,template,template_msk,params.output);
end

ds_sa=struct();

switch params.output
    case 'mean'
        ds_sa.samples=mean(cat(2,pdata{:}),2);
        ds_sa.sa.labels={'corr'};
    case {'raw','correlation'}
        ds_sa.samples=mean(cat(2,pdata{:}),2);

        % half1 varies fastest, matching the column-major order of c(:)
        ds_sa.sa.half1=reshape(repmat((1:nclasses)',nclasses,1),[],1);
        ds_sa.sa.half2=reshape(repmat((1:nclasses),nclasses,1),[],1);

        ds_sa.a.sdim=struct();
        ds_sa.a.sdim.labels={'half1','half2'};
        ds_sa.a.sdim.values={classes, classes};

    case 'mean_by_fold'
        ds_sa.sa.partition=(1:npartitions)';
        ds_sa.samples=[pdata{:}]';

    otherwise
        % unsupported values of 'output' are rejected inside the loop by
        % aggregate_correlations (as long as there is at least one
        % partition)
        assert(false,'this should be caught by aggregate_correlations');
end

function c=apply_post_corr_func(post_corr_func,c)
% Optionally post-process a correlation matrix
%
% c=apply_post_corr_func(post_corr_func,c)
%
% Inputs:
%   post_corr_func   function handle applied to c, or empty ([]) for no
%                    post-processing
%   c                input correlations
%
% Output:
%   c                post_corr_func(c), or c unchanged if post_corr_func
%                    is empty
    if isempty(post_corr_func)
        % nothing to do; return the input as is
        return;
    end

    c=post_corr_func(c);


function agg_c=aggregate_correlations(c,template,template_msk,output)
% Combine the correlations of a single partition into its output form
%
% agg_c=aggregate_correlations(c,template,template_msk,output)
%
% Inputs:
%   c              matrix with (post-processed) correlations between the
%                  classes in the two halves
%   template       weight matrix with the same size as c
%   template_msk   logical mask selecting the elements that are weighted
%   output         'mean', 'mean_by_fold', 'raw' or 'correlation'
%
% Output:
%   agg_c          for 'raw' or 'correlation': all elements of c as a
%                  column vector; for 'mean' or 'mean_by_fold': scalar
%                  with the sum of the template-weighted correlations over
%                  the masked elements.
%
% Any other value for output raises an error.
    switch output
        case {'raw','correlation'}
            agg_c=c(:);

        case {'mean','mean_by_fold'}
            % only elements selected by the mask contribute to the sum
            weighted=c(template_msk).*template(template_msk);
            agg_c=sum(weighted(:));

        case 'one_minus_correlation'
            error(['the ''output'' option ''one_minus_correlation'' '...
                   'has been removed. Please contact CoSMoMVPA''s '...
                   'authors if you really need this option']);

        otherwise
            error('Unsupported output %s',output);
    end

function [unq,pos,ids]=fast_unique(x)
% Fast alternative to unique for a column vector
%
% [unq,pos,ids]=fast_unique(x)
%
% Input:
%   x      Nx1 vector (a column vector is required by the vertical
%          concatenation below)
%
% Outputs:
%   unq    Ux1 vector with the sorted unique values in x
%   pos    Ux1 vector with indices such that unq equals x(pos)
%   ids    Nx1 vector with indices such that x equals unq(ids)
%
% Notes:
%   - as with unique, every NaN in x is considered a distinct value
%     (NaN values are sorted last)
%   - an empty x gives empty (0x1) outputs
    if isempty(x)
        % without this, [true; ...] below would mark a non-existing first
        % element and the indexing would fail
        unq=zeros(0,1);
        pos=zeros(0,1);
        ids=zeros(0,1);
        return;
    end

    % optimized for vectors
    [y,i]=sort(x,1);

    % mark the first element of each run of equal values in the sorted
    % vector. Comparing neighbours with ~= (instead of diff(y)>0) merges
    % equal infinite values (Inf-Inf is NaN) while keeping NaN values
    % separate (NaN~=NaN is true), rather than merging NaN into the
    % preceding (largest finite) value.
    msk=[true;y(2:end)~=y(1:end-1)];
    j=find(msk);
    pos=i(j);
    unq=y(j);

    % class index of each sorted element, mapped back to the input order
    vs=cumsum(msk);
    ids=vs;
    ids(i)=vs;


function data=get_data(ds, sample_idxs, class_ids, merge_func)
% Merge selected samples into one row per class
%
% data=get_data(ds, sample_idxs, class_ids, merge_func)
%
% Inputs:
%   ds            dataset struct with field .samples
%   sample_idxs   indices of the rows in ds.samples to use
%   class_ids     class index (1..nclasses) for every row in ds.samples
%   merge_func    function handle applied to the rows of each class, or
%                 empty ([]) to average the rows of each class
%
% Output:
%   data          nclasses x nfeatures matrix, with nclasses=max(class_ids);
%                 row k contains the merged samples of class k
%
% An error is raised if a class has no samples among sample_idxs; the
% message reports the class index k.
    samples=ds.samples(sample_idxs,:);
    target_ids=class_ids(sample_idxs);

    nclasses=max(class_ids);

    merge_by_averaging=isempty(merge_func);
    if isequal(target_ids',1:nclasses) && merge_by_averaging
        % optimize standard case of one sample per class and normal
        % averaging over samples
        data=samples;
        return
    end

    nfeatures=size(samples,2);
    data=zeros(nclasses,nfeatures);

    for k=1:nclasses
        msk=target_ids==k;

        n=sum(msk);

        % check before merging, so that merge_func is never called with
        % an empty input and the error reports the missing class index
        if n==0
            error('missing target class %d', k);
        end

        class_samples=samples(msk,:);

        if merge_by_averaging
            if n==1
                data(k,:)=class_samples;
            else
                data(k,:)=sum(class_samples,1)/n;
            end
        else
            data(k,:)=merge_func(class_samples);
        end
    end


function show_warning_if_no_defaults(defaults,params)
% Warn about Fisher transformation when non-default options are used
%
% show_warning_if_no_defaults(defaults,params)
%
% Inputs:
%   defaults    struct with the default options
%   params      struct with the options actually used
%   (both must contain at least the fields 'output' and 'post_corr_func')
%
% A warning is shown through cosmo_warning if 'output' or 'post_corr_func'
% differ from their defaults while post_corr_func is still @atanh (i.e.
% the user changed other options but Fisher transformation still applies).
    keys={'output','post_corr_func'};
    has_defaults=isequal(select_fields(defaults,keys),...
                         select_fields(params,keys));

    if ~has_defaults && isequal(params.post_corr_func,@atanh)
        name=mfilename();
        msg=sprintf(...
                ['Please note that the ''%s'' function applies '...
                'Fisher transformation after the correlations '...
                'have been computed. This was a somewhat unfortunate '...
                'implementation decision that will not be changed '...
                'to avoid breaking behaviour with earlier versions.\n'...
                'To disable using the Fisher transformation, use the '...
                '''%s'' function while setting the ''post_corr_func'' '...
                'option to the empty array ([])'],name,name);
        cosmo_warning(msg);
    end

function subset=select_fields(s, keys)
% Return a struct containing only the named fields of s
%
% subset=select_fields(s, keys)
%
% Inputs:
%   s        struct that contains every field named in keys
%   keys     cell array with field names
%
% Output:
%   subset   struct with the fields keys{:} copied from s; an empty
%            struct (no fields) if keys is empty
    % initialize so that the output is always assigned, even without keys
    subset=struct();
    for k=1:numel(keys)
        key=keys{k};
        subset.(key)=s.(key);
    end