cosmo target dsm corr measure skl

function ds_sa = cosmo_target_dsm_corr_measure(ds, varargin)
    % measure correlation with target dissimilarity matrix
    %
    % ds_sa = cosmo_target_dsm_corr_measure(dataset, args)
    %
    % Inputs:
    %   ds             dataset struct with field .samples PxQ for P samples and
    %                  Q features
    %   args           struct with fields:
    %     .target_dsm  (optional) Either:
    %                  - target dissimilarity matrix of size PxP. It should
    %                    have zeros on the diagonal and be symmetric.
    %                  - target dissimilarity vector of size Nx1, with
    %                    N=P*(P-1)/2 the number of pairs of samples in ds.
    %                  This option is mutually exclusive with the 'glm_dsm'
    %                  option
    %     .metric      (optional) distance metric used in pdist to compute
    %                  pair-wise distances between samples in ds. It accepts
    %                  any metric supported by pdist (default: 'correlation')
    %     .type        (optional) type of correlation between target_dsm and
    %                  ds, one of 'Pearson' (default), 'Spearman', or
    %                  'Kendall'. If the 'regress_dsm' option is used, then the
    %                  specified correlation type is used for the partial
    %                  correlaton computation, and a Pearson correlation
    %                  is returned on the residuals (that is, the
    %                  remaining dissimilarities after controlling for those
    %                  in regress_dsm).
    %     .regress_dsm (optional) target dissimilarity matrix or vector (as
    %                  .target_dsm), or a cell with matrices or vectors, that
    %                  should be regressed out. If this option is provided then
    %                  the output is the partial (Pearson) correlation between
    %                  the pairwise distances between samples in ds and
    %                  target_dsm, after controlling for the effect of the
    %                  matrix (or matrices) in regress_dsm. (Using this option
    %                  yields similar behaviour as the Matlab function
    %                  'partial_corr').
    %                  When using this option, the 'type' option cannot be
    %                  set to 'Kendall'.
    %     .glm_dsm     (optional) cell with model dissimilarity matrices or
    %                  vectors (as .target_dsm) for using a general linear
    %                  model to get regression coefficients for each element in
    %                  .glm_dsm. Both the input data and the dissimilarity
    %                  matrices are z-scored before estimating the regression
    %                  coefficients.
    %                  This option is required when 'target_dsm' is not
    %                  provided; it cannot cannot used together with the
    %                  'target_dsm' or 'regress_dsm' options.
    %                  When using this option, the 'type' option cannot be
    %                  set to 'Spearman' or 'Kendall'.
    %                  For this option, the output has as many rows (samples)
    %                  as there are elements (dissimilarity matrices) in
    %                  .glm_dsm.
    %     .center_data If set to true, then the mean of each feature (column in
    %                  ds.samples) is subtracted from each column prior to
    %                  computing the pairwise distances for all samples in ds.
    %                  This is generally recommended but is not the default in
    %                  order to avoid breaking behavaiour from earlier
    %                  versions. For a rationale why this is recommendable, see
    %                  the Diedrichsen & Kriegeskorte article (below in
    %                  references)
    %                  Default: false
    %
    % Output:
    %    ds_sa         Dataset struct with fields:
    %      .samples    Scalar correlation value between the pair-wise
    %                  distances of the samples in ds and target_dsm; or
    %                  (when 'glm_dsms' is supplied) a column vector with
    %                  normalized beta coefficients. These values
    %                  are untransformed (e.g. there is no Fisher transform).
    %      .sa         Struct with field:
    %        .labels   {'rho'}; or (when 'glm_dsm' is supplied) a cell
    %                  {'beta1','beta2',...}.
    %
    % Examples:
    %     % generate synthetic dataset with 6 classes (conditions),
    %     % one sample per class
    %     ds=cosmo_synthetic_dataset('ntargets',6,'nchunks',1);
    %     %
    %     % create target dissimilarity matrix to test whether
    %     % - class 1 and 2 are similar (and different from classes 3-6)
    %     % - class 3 and 4 are similar (and different from classes 1,2,5,6)
    %     % - class 5 and 6 are similar (and different from classes 1-4)
    %     target_dsm=1-kron(eye(3),ones(2));
    %     %
    %     % show the target dissimilarity matrix
    %     cosmo_disp(target_dsm);
    %     %|| [ 0         0         1         1         1         1
    %     %||   0         0         1         1         1         1
    %     %||   1         1         0         0         1         1
    %     %||   1         1         0         0         1         1
    %     %||   1         1         1         1         0         0
    %     %||   1         1         1         1         0         0 ]
    %     %
    %     % compute similarity between pairw-wise similarity of the
    %     % patterns in the dataset and the target dissimilarity matrix
    %     dcm_ds=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_dsm);
    %     %
    %     % Pearson correlation is about 0.56
    %     cosmo_disp(dcm_ds)
    %     %|| .samples
    %     %||   [ 0.562 ]
    %     %|| .sa
    %     %||   .labels
    %     %||     { 'rho' }
    %     %||   .metric
    %     %||     { 'correlation' }
    %     %||   .type
    %     %||     { 'Pearson' }
    %     %
    %     % do not consider classes 3 and 5
    %     target_dsm([3,5],:)=NaN;
    %     target_dsm(:,[3,5])=NaN;
    %     target_dsm(3,3)=0;
    %     target_dsm(5,5)=0;
    %     %
    %     % show the updated target dissimilarity matrix
    %     cosmo_disp(target_dsm);
    %     %|| [   0         0       NaN         1       NaN         1
    %     %||     0         0       NaN         1       NaN         1
    %     %||   NaN       NaN         0       NaN       NaN       NaN
    %     %||     1         1       NaN         0       NaN         1
    %     %||   NaN       NaN       NaN       NaN         0       NaN
    %     %||     1         1       NaN         1       NaN         0 ]
    %     %
    %     % compute similarity between pairw-wise similarity of the
    %     % patterns in the dataset and the target dissimilarity matrix
    %     dcm_ds=cosmo_target_dsm_corr_measure(ds,'target_dsm',target_dsm);
    %     %
    %     % Correlation is different because classes 3 and 5 were left out
    %     cosmo_disp(dcm_ds)
    %     %|| .samples
    %     %||   [ 0.705 ]
    %     %|| .sa
    %     %||   .labels
    %     %||     { 'rho' }
    %     %||   .metric
    %     %||     { 'correlation' }
    %     %||   .type
    %     %||     { 'Pearson' }
    %
    %
    %
    % Notes:
    %   - for group analysis, correlations can be fisher-transformed
    %     through:
    %       dcm_ds.samples=atanh(dcm_ds.samples)
    %   - it is recommended to set the 'center_data' to true when using
    %     the default 'correlation' metric, as this removes a main effect
    %     common to all samples; but note that this option is disabled by
    %     default due to historical reasons.
    %   - elements in the *_dsm dissimilarity matrices can have NaNs, in which
    %     case their value, as well as the corresponding location in the
    %     dataset's samples, are ignored. Masking is done prior to z-score
    %     normalization.
    %
    % Reference:
    %   - Diedrichsen, J., & Kriegeskorte, N. (2017). Representational
    %     models: A common framework for understanding encoding,
    %     pattern-component, and representational-similarity analysis.
    %     PLoS computational biology, 13(4), e1005508.
    %
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % process input arguments
    params = cosmo_structjoin('type', 'Pearson', ...
                              'metric', 'correlation', ...
                              'center_data', false, ...
                              varargin);

    check_input(ds);
    check_params(params);

    % - compute the pair-wise distance between all dataset samples using
    %   cosmo_pdist

    samples = ds.samples;
    if params.center_data
        samples = bsxfun(@minus, samples, mean(samples, 1));
    end

    samples_pdist = cosmo_pdist(samples, params.metric)';

    has_model_dsms = isfield(params, 'glm_dsm');

    if has_model_dsms
        ds_sa = linear_regression_dsm(samples_pdist, params);
    else
        ds_sa = correlation_dsm(samples_pdist, params);
    end

    check_output(ds, ds_sa);

function check_output(input_ds, output_ds_sa)
    if any(isnan(output_ds_sa.samples))
        if any(isnan(input_ds.samples(:)))
            msg = ['Input dataset has NaN values, which results in '...
                   'NaN values in the output. Consider masking the '...
                   'dataset to remove NaN values'];
        elseif any(var(input_ds.samples) == 0)
            msg = ['Input dataset has constant or infinite features, ', ...
                   'which results in NaN values in the output. '...
                   'Consider masking the dataset to remove constant '...
                   'or non-finite features, for example using '...
                   'cosmo_remove_useless_data'];
        else
            msg = ['Output has NaN values, even though the input does '...
                   'not. This can be due to the presence of constant '...
                   'features and/or non-finite values in the input, '...
                   'and/or target similarity structures with constant '...
                   'and/of non-finite data. When in doubt, please '...
                   'contact the CoSMoMVPA developers'];
        end
        cosmo_warning(msg);
    end

function ds_sa = correlation_dsm(samples_pdist, params)
    npairs_dataset = numel(samples_pdist);

    % get target dsm in vector form
    [target_dsm_vec, target_msk] = get_dsm_vec_from_struct(params, ...
                                                           'target_dsm', ...
                                                           npairs_dataset);

    if ~isempty(target_msk)
        samples_pdist = samples_pdist(target_msk);
    end

    % ensure the size of the dataset matches the matrix
    has_regress_dsm = isfield(params, 'regress_dsm');
    if has_regress_dsm
        [regress_dsm_mat, regress_msk] = get_dsm_mat_from_vector_or_cell( ...
                                                                         params.regress_dsm, ...
                                                                         npairs_dataset);

        if ~isequal(target_msk, regress_msk)
            error(['NaN mask non-match between ''target_dsm'' '...
                   'and ''regress_dsm''']);
        end

        [samples_pdist(:), target_dsm_vec(:)] = regress_out( ...
                                                            samples_pdist, ...
                                                            target_dsm_vec, ...
                                                            regress_dsm_mat, ...
                                                            params.type);
        % regressed out samples are either based on Pearson or Spearman
        % correlation; Pearson correlations are computed for the final
        % correlation computation
        corr_type = 'Pearson';
    else
        corr_type = params.type;
    end

    rho = cosmo_corr(samples_pdist, target_dsm_vec, corr_type);

    % store results
    ds_sa = struct();
    ds_sa.samples = rho;
    ds_sa.sa.labels = {'rho'};
    ds_sa.sa.metric = {params.metric};
    ds_sa.sa.type = {params.type};

function ds_sa = linear_regression_dsm(samples_pdist, params)
    if ~strcmp(params.type, 'Pearson')
        error(['For linear regression, only ''Pearson'' is '...
               'currently allowed']);
    end
    npairs_dataset = numel(samples_pdist);

    [dsm_mat, msk] = get_dsm_mat_from_vector_or_cell(params.glm_dsm, ...
                                                     npairs_dataset);

    % normalize matrices
    dsm_mat_zscore = cosmo_normalize(dsm_mat, 'zscore');

    if ~isempty(msk)
        samples_pdist = samples_pdist(msk);
    end

    % normalize data
    samples_pdist_zscore = cosmo_normalize(samples_pdist(:), 'zscore');

    % estimate betas based on masked samples
    betas = dsm_mat_zscore \ samples_pdist_zscore;

    % construct labels
    nvec = size(dsm_mat_zscore, 2);
    labels = cell(nvec, 1);
    for k = 1:nvec
        labels{k} = sprintf('beta%d', k);
    end

    ds_sa = struct();
    ds_sa.samples = betas;
    ds_sa.sa.labels = labels;
    ds_sa.sa.metric = repmat({params.metric}, nvec, 1);

function [ds_resid, target_resid] = regress_out(ds_pdist, ...
                                                target_dsm_vec, ...
                                                regress_dsm_mat, ...
                                                partial_corr_type)

    if strcmp(partial_corr_type, 'Spearman')
        % use ranks of the dissimilarity matrices
        ds_pdist = cosmo_tiedrank(ds_pdist, 1);
        target_dsm_vec = cosmo_tiedrank(target_dsm_vec, 1);
        regress_dsm_mat = cosmo_tiedrank(regress_dsm_mat, 1);
    elseif strcmp(partial_corr_type, 'Kendall')
        error('Kendall cannot be used with the ''regress_dsm'' option');
    end

    % set up design matrix
    nsamples = size(ds_pdist, 1);
    regr = [regress_dsm_mat ones(nsamples, 1)];

    % put ds_pdist and target_dsm_vec together
    both = [ds_pdist target_dsm_vec];

    % compute residuals
    both_resid = both - regr * (regr \ both);

    ds_resid = both_resid(:, 1);
    target_resid = both_resid(:, 2);

function [dsm_mat, common_msk] = get_dsm_mat_from_vector_or_cell(dsm_cell, ...
                                                                 npairs_dataset)
    if isnumeric(dsm_cell)
        dsm_cell = {dsm_cell};
    elseif ~iscell(dsm_cell)
        error('dsm inputs must be provided in a cell');
    end

    n = numel(dsm_cell);
    for k = 1:n
        name = sprintf('.model_dsms{%d}', k);
        [vec, msk_k] = get_dsm_vec(dsm_cell{k}, npairs_dataset, name);

        if k == 1
            nrows = numel(vec);
            dsm_mat = zeros(nrows, n);

            common_msk = msk_k;
        end

        if ~isempty(msk_k) && k ~= 1 && ~isequal(common_msk, msk_k)
            error('DSMs NaN mask mismatch for matrices %d and %d', ...
                  1, k);
        end

        dsm_mat(:, k) = vec;
    end

function [dsm_vec, msk] = get_dsm_vec_from_struct(params, name, npairs_dataset)
    % helper function to get dsm in vector form
    if ~isfield(params, name)
        error('Missing parameter ''%s''', name);
    end

    dsm = params.(name);
    [dsm_vec, msk] = get_dsm_vec(dsm, npairs_dataset, ['''' name '''']);

function [dsm_vec, msk] = get_dsm_vec(dsm, npairs_dataset, name)
    % helper function to get dsm in column vector form
    if ~isnumeric(dsm)
        error('dsm inputs must be numeric, found %s', class(dsm));
    end

    sz = size(dsm);

    if sz(1) == 1
        dsm_vec = dsm';
    elseif sz(2) == 1
        dsm_vec = dsm;
    else
        % convert square matrix to vector
        dsm_vec = cosmo_squareform(dsm, 'tovector')';
    end

    if npairs_dataset ~= numel(dsm_vec)
        error(['Sample size mismatch between dataset (%d pairs) '...
               'and %s in vector form (%d pairs)'], ...
              npairs_dataset, name, numel(dsm_vec));
    end

    msk = [];

    nan_msk = isnan(dsm_vec);
    if any(nan_msk)
        % apply mask
        msk = ~nan_msk;
        dsm_vec = dsm_vec(msk);
    end

function check_input(ds)
    % for safety require targets to be 1:N
    has_sa = isstruct(ds) && isfield(ds, 'sa');
    if ~has_sa || ~isfield(ds.sa, 'targets') || ~isfield(ds, 'samples')
        error('Missing field .sa.targets or .samples');
    end

    nsamples = size(ds.samples, 1);

    if ~isequal(ds.sa.targets', 1:nsamples)
        msg = sprintf('.sa.targets must be (1:%d)''', nsamples);
        if isequal(unique(ds.sa.targets), (1:nsamples)')
            msg = sprintf(['%s\nMultiple samples with the same chunks '...
                           'can be averaged using cosmo_fx'], msg);
        else
            msg = sprintf(['%s\nConsider setting .sa.chunks to q, where '...
                           '[~,~,q]=unique(ds.sa.targets)', msg]);
        end
        error(msg);
    end

function check_params(params)
    if isfield(params, 'glm_dsm')
        if isfield(params, 'regress_dsm') || isfield(params, 'target_dsm')
            error(['''glm_dsm'' cannot be used with ''regress_dsm'''...
                   'or ''target_dsm''']);
        end
    elseif ~isfield(params, 'target_dsm')
        error('''target_dsm'' or ''glm_dsm'' option is required');
    end

    ensure_is_valid_correlation(params, 'type');

function ensure_is_valid_correlation(params, key)
    value = params.(key);

    allowed_types = {'Spearman', 'Pearson', 'Kendall'};

    if ~(ischar(value) ...
         && any(strcmp(value, allowed_types)))
        error('''%s'' argument must be one of: ''%s''', ...
              key, cosmo_strjoin(allowed_types, '''%s'', '));
    end