function ds_sa=cosmo_correlation_measure(ds, varargin)
% Computes a split-half correlation measure
%
% ds_sa=cosmo_correlation_measure(ds[, args])
%
% Inputs:
% ds dataset structure with fields .samples, .sa.targets and
% .sa.chunks
% args optional struct with the following optional fields:
% .partitions struct with fields .train_indices and .test_indices.
% Both should be a Px1 cell for P partitions. If omitted,
% it is set to cosmo_nchoosek_partitioner(ds,'half').
% .template QxQ matrix for Q classes in each chunk. This matrix
% weights the correlations across the two halves.
% If ds.sa.targets has only one unique value, it must be
% set to the scalar value 1; otherwise it should
% have a mean of zero. If omitted, it has values of
% 1/Q on the diagonal and -1/(Q*(Q-1)) off the
% diagonal.
% (Note: this can be used to test for representational
% similarity matching)
% .merge_func A function handle used to merge data from matching
% targets in the same chunk. Default is @(x) mean(x,1),
% meaning that values are averaged over the same samples.
% It is assumed that isequal(args.merge_func(y),y) holds
% if y is a row vector. (A custom merge_func is
% illustrated at the end of the Example section below.)
% .corr_type Type of correlation: 'Pearson','Spearman','Kendall'.
% The default is 'Pearson'.
% .post_corr_func Operation performed after correlation. (default:
% @atanh)
% .output 'mean' (default): correlations weighted by template
% 'raw' or 'correlation': correlations between all classes
% 'mean_by_fold': weighted correlations for each
% fold in the partitions.
% (The 'one_minus_correlation' option has been
% removed; using it raises an error.)
%
%
% Output:
% ds_sa Struct with fields:
% .samples Scalar indicating how well the template matrix
% correlates with the correlation matrix from the two
% halves (averaged over partitions). By default:
% - this value is based on Fisher-transformed correlation
% values, not raw correlation values
% - this is the average of the (Fisher-transformed)
% on-diagonal minus the average of the
% (Fisher-transformed) off-diagonal elements of the
% correlation matrix based on the two halves of the data.
% .sa Struct with fields:
% .labels if output=='mean'
% .half1 } if output=='raw': (N^2)x1 vectors with indices of data
% .half2 } from two halves, with N the number of unique targets.
% .partition if output=='mean_by_fold': Px1 fold index for
% P partitions
%
% Example:
% ds=cosmo_synthetic_dataset();
% %
% % compute on-minus-off diagonal correlations
% c=cosmo_correlation_measure(ds);
% cosmo_disp(c)
% %|| .samples
% %|| [ 1.23 ]
% %|| .sa
% %|| .labels
% %|| { 'corr' }
% %
% % Spearman correlation requires the Matlab statistics toolbox
% cosmo_skip_test_if_no_external('@stats');
% %
% c=cosmo_correlation_measure(ds,'corr_type','Spearman');
% cosmo_disp(c)
% %|| .samples
% %|| [ 1.28 ]
% %|| .sa
% %|| .labels
% %|| { 'corr' }
%
% ds=cosmo_synthetic_dataset();
% % get raw correlations without Fisher transformation
% c_raw=cosmo_correlation_measure(ds,'output','correlation',...
% 'post_corr_func',[]);
% cosmo_disp(c_raw)
% %|| .samples
% %|| [ 0.363
% %|| -0.404
% %|| -0.447
% %|| 0.606 ]
% %|| .sa
% %|| .half1
% %|| [ 1
% %|| 2
% %|| 1
% %|| 2 ]
% %|| .half2
% %|| [ 1
% %|| 1
% %|| 2
% %|| 2 ]
% %|| .a
% %|| .sdim
% %|| .labels
% %|| { 'half1' 'half2' }
% %|| .values
% %|| { [ 1 [ 1
% %|| 2 ] 2 ] }
% %
% % convert to matrix form (N x N x P, with N the number of classes and
% % P=1)
% matrices=cosmo_unflatten(c_raw,1);
% cosmo_disp(matrices)
% %|| [ 0.363 -0.447
% %|| -0.404 0.606 ]
%
% % compute for each fold separately, using a custom take-one-chunk
% % out partitioning scheme and without Fisher transformation of
% % the correlations.
% % Note that c.sa.partition in the output indexes the fold
% % (partition) of each correlation value
% ds=cosmo_synthetic_dataset('type','fmri','nchunks',4);
% partitions=cosmo_nfold_partitioner(ds);
% c=cosmo_correlation_measure(ds,'output','mean_by_fold',...
% 'partitions',partitions,...
% 'post_corr_func',[]);
% cosmo_disp(c.samples);
% %|| [ 1.32
% %|| 0.512
% %|| 1.05
% %|| 1.23 ]
% cosmo_disp(c.sa);
% %|| .partition
% %|| [ 1
% %|| 2
% %|| 3
% %|| 4 ]
%
% % minimal searchlight example
% ds=cosmo_synthetic_dataset('type','fmri');
% % use searchlight with radius 1 voxel (radius=3 is more typical)
% nbrhood=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
% % run searchlight
% res=cosmo_searchlight(ds,nbrhood,@cosmo_correlation_measure,...
% 'progress',false);
% cosmo_disp(res.samples)
% %|| [ 1.87 1.25 1.51 1.68 1.71 0.879 ]
% cosmo_disp(res.sa)
% %|| .labels
% %|| { 'corr' }
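% %
% % Sketch (not a verified doctest): use a simple odd-even partition
% % scheme and a custom merge function. The median-based merge_func
% % here is just an illustration; cosmo_oddeven_partitioner is part
% % of CoSMoMVPA (see the Notes below).
% ds=cosmo_synthetic_dataset('type','fmri','nchunks',4);
% partitions=cosmo_oddeven_partitioner(ds,'half');
% c=cosmo_correlation_measure(ds,'partitions',partitions,...
% 'merge_func',@(x)median(x,1));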
%
% Notes:
% - by default the post_corr_func is set to @atanh. This is equivalent to
% a Fisher transformation from r (correlation) to z (standard z-score).
% The underlying math is z=atanh(r)=0.5*log((1+r)./(1-r)).
% The rationale is to make data more normally distributed under the
% null hypothesis.
% Fisher-transformed correlations can be transformed back to
% their original correlation values using 'tanh', which is the inverse
% of 'atanh' (see the sketch at the end of these notes).
% - To disable the Fisher transformation (used by default), set the
% 'post_corr_func' option to [].
% - if multiple samples are present with the same chunk and target, they
% are averaged *prior* to computing the correlations.
% - if multiple partitions are present, then the correlations are
% computed separately for each partition, and then averaged (unless
% the 'output' option is set to a value other than 'mean').
% - When more than two chunks are present in the input, partitions
% consist of all possible half splits for which the number of unique
% chunks in the train and test set differs by at most one.
% For illustration, the partitions for two to six chunks are:
%     - 2 chunks  - partition    #1
%       chunks first half:     {  1
%       chunks second half:    {  2
%
%     - 3 chunks  - partition    #1  #2  #3
%       chunks first half:     {  3   2   1
%       chunks second half:    {  1   1   2
%                              {  2   3   3
%
%     - 4 chunks  - partition    #1  #2  #3
%       chunks first half:     {  3   2   2
%                              {  4   4   3
%       chunks second half:    {  1   1   1
%                              {  2   3   4
%
%     - 5 chunks  - partition    #1  #2  #3  #4  #5  #6  #7  #8  #9 #10
%       chunks first half:     {  4   3   3   2   2   2   1   1   1   1
%                              {  5   5   4   5   4   3   5   4   3   2
%       chunks second half:    {  1   1   1   1   1   1   2   2   2   3
%                              {  2   2   2   3   3   4   3   3   4   4
%                              {  3   4   5   4   5   5   4   5   5   5
%
%     - 6 chunks  - partition    #1  #2  #3  #4  #5  #6  #7  #8  #9 #10
%       chunks first half:     {  4   3   3   3   2   2   2   2   2   2
%                              {  5   5   4   4   5   4   4   3   3   3
%                              {  6   6   6   5   6   6   5   6   5   4
%       chunks second half:    {  1   1   1   1   1   1   1   1   1   1
%                              {  2   2   2   2   3   3   3   4   4   5
%                              {  3   4   5   6   4   5   6   5   6   6
% Thus, with an increasing number of chunks, the number of partitions
% (and thus the time required to run this function) grows rapidly
% (roughly exponentially). To use simpler partition schemes (e.g. odd-even, as
% provided by cosmo_oddeven_partitioner), specify the 'partitions'
% argument.
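%
% The following sketch (plain MATLAB, not a verified doctest; the
% values of r and Q are arbitrary) illustrates the Fisher
% transformation and the default template described above:
%
% % Fisher z-transform and its inverse
% r=0.8;
% z=atanh(r); % equals 0.5*log((1+r)/(1-r))
% r_back=tanh(z); % recovers r
% %
% % default template for Q=3 classes: 1/Q on the diagonal and
% % -1/(Q*(Q-1)) off the diagonal, summing to zero. For a QxQ
% % correlation matrix C, sum(template(:).*C(:)) is the mean
% % on-diagonal minus the mean off-diagonal element of C.
% Q=3;
% template=(eye(Q)-1/Q)/(Q-1);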
%
% References
% - Haxby, J. V. et al (2001). Distributed and overlapping
% representations of faces and objects in ventral temporal cortex.
% Science 293, 2425-2430
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
persistent cached_partitions;
persistent cached_chunks;
persistent cached_params;
persistent cached_varargin;
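% the partitions are cached as well: computing the default half
% partitions is relatively expensive, and in a searchlight this
% function is called many times with identical .sa.chunks values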
% optimize parameter parsing: if the same arguments were used in
% a previous call, do not call cosmo_structjoin (which is relatively
% expensive)
if ~isempty(cached_params) && isequal(cached_varargin, varargin)
params=cached_params;
else
defaults=struct();
defaults.partitions=[];
defaults.template=[];
defaults.merge_func=[];
defaults.corr_type='Pearson';
defaults.post_corr_func=@atanh;
defaults.output='mean';
defaults.check_partitions=true;
params=cosmo_structjoin(defaults, varargin);
show_warning_if_no_defaults(defaults,params);
cached_params=params;
cached_varargin=varargin;
end
partitions=params.partitions;
template=params.template;
merge_func=params.merge_func;
post_corr_func=params.post_corr_func;
check_partitions=params.check_partitions;
chunks=ds.sa.chunks;
if isempty(partitions)
if ~isempty(cached_chunks) && isequal(cached_chunks, chunks)
partitions=cached_partitions;
% assume that partitions were already checked
check_partitions=false;
else
partitions=cosmo_nchoosek_partitioner(ds,'half');
cached_chunks=chunks;
cached_partitions=partitions;
end
end
if check_partitions
cosmo_check_partitions(partitions, ds, params);
end
targets=ds.sa.targets;
nsamples=size(targets,1);
[classes,unused,class_ids]=fast_unique(targets);
nclasses=numel(classes);
switch nclasses
case 0
error('No classes found - this is not supported');
case 1
if ~isequal(template,1)
error([ 'Only one unique value for .sa.targets was found; '...
'this is only allowed when the ''template'' '...
'parameter is explicitly set to 1.\n'...
'Note that this option does not compute the '...
'correlation *difference* between matching and '...
'non-matching values in .sa.targets; instead it '...
'computes the direct correlation '...
'between two halves of the data. A typical use '...
'case is when .samples already contains a '...
'difference score comparing two different '...
'conditions']);
end
otherwise
if isempty(template)
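% default template: 1/nclasses on the diagonal and
% -1/(nclasses*(nclasses-1)) off the diagonal, so that all
% elements sum to zero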
template=(eye(nclasses)-1/nclasses)/(nclasses-1);
else
max_tolerance=1e-8;
if abs(sum(template(:)))>max_tolerance
error('Template matrix does not have a sum of zero');
end
expected_size=[nclasses,nclasses];
if ~isequal(size(template),expected_size)
error('Template must be of size %dx%d',expected_size);
end
end
end
template_msk=isfinite(template);
npartitions=numel(partitions.train_indices);
% space for output
pdata=cell(npartitions,1);
% keep track of how often each chunk was used in test_indices,
% and which chunk was used last
test_chunks_last=NaN(nsamples,1);
test_chunks_count=zeros(nsamples,1);
for k=1:npartitions
train_indices=partitions.train_indices{k};
test_indices=partitions.test_indices{k};
% get data in each half
half1=get_data(ds, train_indices, class_ids, merge_func);
half2=get_data(ds, test_indices, class_ids, merge_func);
% compute raw correlations
raw_c=cosmo_corr(half1', half2', params.corr_type);
% apply post-processing (usually Fisher-transform, i.e. atanh)
c=apply_post_corr_func(post_corr_func,raw_c);
% aggregate results
pdata{k}=aggregate_correlations(c,template,template_msk,params.output);
test_chunks_last(test_indices)=chunks(test_indices);
test_chunks_count(test_indices)=test_chunks_count(test_indices)+1;
end
ds_sa=struct();
switch params.output
case 'mean'
ds_sa.samples=mean(cat(2,pdata{:}),2);
ds_sa.sa.labels={'corr'};
case {'raw','correlation'}
ds_sa.samples=mean(cat(2,pdata{:}),2);
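% the correlation matrix is stored in column-major ('flattened')
% order: half1 (the row index) varies fastest, matching c(:)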
ds_sa.sa.half1=reshape(repmat((1:nclasses)',nclasses,1),[],1);
ds_sa.sa.half2=reshape(repmat((1:nclasses),nclasses,1),[],1);
ds_sa.a.sdim=struct();
ds_sa.a.sdim.labels={'half1','half2'};
ds_sa.a.sdim.values={classes, classes};
case 'mean_by_fold'
ds_sa.sa.partition=(1:npartitions)';
ds_sa.samples=[pdata{:}]';
otherwise
assert(false,'this should have been caught by aggregate_correlations');
end
function c=apply_post_corr_func(post_corr_func,c)
if ~isempty(post_corr_func)
c=post_corr_func(c);
end
function agg_c=aggregate_correlations(c,template,template_msk,output)
switch output
case {'mean','mean_by_fold'}
pcw=c(template_msk).*template(template_msk);
agg_c=sum(pcw(:));
case {'raw','correlation'}
agg_c=c(:);
case 'one_minus_correlation'
error(['the ''output'' option ''one_minus_correlation'' '...
'has been removed. Please contact CoSMoMVPA''s '...
'authors if you really need this option']);
otherwise
error('Unsupported output %s',output);
end
function [unq,pos,ids]=fast_unique(x)
% optimized replacement of unique() for column vectors:
% unq: unique values in ascending order
% pos: position in x of the first occurrence of each unique value
% ids: for each element of x, the index of its value in unq
[y,i]=sort(x,1);
msk=[true;diff(y)>0]; % true at the first occurrence of each sorted value
j=find(msk);
pos=i(j);
unq=y(j);
vs=cumsum(msk); % running count of unique values seen so far
ids=vs;
ids(i)=vs; % map class indices back to the original order of x
function data=get_data(ds, sample_idxs, class_ids, merge_func)
samples=ds.samples(sample_idxs,:);
target_ids=class_ids(sample_idxs);
nclasses=max(class_ids);
merge_by_averaging=isempty(merge_func);
if isequal(target_ids',1:nclasses) && merge_by_averaging
% optimize standard case of one sample per class and normal
% averaging over samples
data=samples;
return
end
nfeatures=size(samples,2);
data=zeros(nclasses,nfeatures);
for k=1:nclasses
msk=target_ids==k;
n=sum(msk);
if n==0
% report the (1-based) class index that has no samples
error('missing target class #%d', k);
end
class_samples=samples(msk,:);
if merge_by_averaging
if n==1
data(k,:)=class_samples;
else
data(k,:)=sum(class_samples,1)/n;
end
else
data(k,:)=merge_func(class_samples);
end
end
function show_warning_if_no_defaults(defaults,params)
keys={'output','post_corr_func'};
has_defaults=isequal(select_fields(defaults,keys),...
select_fields(params,keys));
if ~has_defaults && isequal(params.post_corr_func,@atanh)
name=mfilename();
msg=sprintf(...
['Please note that the ''%s'' function applies a '...
'Fisher transformation after the correlations '...
'have been computed. This was a somewhat unfortunate '...
'implementation decision that will not be changed, '...
'to avoid breaking behaviour with earlier versions.\n'...
'To disable the Fisher transformation, use the '...
'''%s'' function with the ''post_corr_func'' '...
'option set to the empty array ([])'],name,name);
cosmo_warning(msg);
end
function subset=select_fields(s, keys)
% return a struct containing only the fields in keys, copied from s
subset=struct(); % ensure the output is defined even when keys is empty
n=numel(keys);
for k=1:n
key=keys{k};
subset.(key)=s.(key);
end