cosmo distatis skl

function res = cosmo_distatis(ds, varargin)
    % apply DISTATIS measure to each feature
    %
    % res=cosmo_statis_measure(ds, opt)
    %
    % Inputs:
    %    ds               dataset struct with dissimilarity values; usually
    %                     the output from @cosmo_dissimilarity_matrix_measure
    %                     applied to each subject followed by cosmo_stack. It
    %                     can also be a cell with datasets (one per subject).
    %    'return', d      d can be 'distance' (default) or 'crossproduct'.
    %                     'distance' returns a distance matrix, whereas
    %                     'crossproduct' returns a crossproduct matrix
    %    'split_by', s    sample attribute that discriminates chunks
    %                     (participants) (default: 'chunks')
    %    'shape', sh      shape of output if it were unflattened using
    %                     cosmo_unflatten, either 'square' (default) or
    %                     'triangle' (which gives the lower diagonal of the
    %                     distance matrix)
    %
    % Returns:
    %    res              result dataset struct with feature-wise optimal
    %                     compromise distance matrix across subjects
    %      .samples
    %
    %
    % Example:
    %     % (This example cannot be documentation tested using Octave,
    %     %  since Octave does not allow for-loops with evalc)
    %     cosmo_skip_test_if_no_external('matlab');
    %     %
    %     ds=cosmo_synthetic_dataset('nsubjects',5,'nchunks',1,'ntargets',4);
    %     %
    %     % define neighborhood (here a searchlight with radius of 1 voxel)
    %     nbrhood=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
    %     %
    %     % define measure
    %     measure=@cosmo_dissimilarity_matrix_measure;
    %     % each subject is a chunk
    %     ds.sa.chunks=ds.sa.subject;
    %     % compute DSM for each subject
    %     sp=cosmo_split(ds,'chunks');
    %     for k=1:numel(sp)
    %         sp{k}=cosmo_searchlight(sp{k},nbrhood,measure,'progress',false);
    %         sp{k}.sa.chunks=ones(6,1)*k;
    %     end
    %     % merge results
    %     dsms=cosmo_stack(sp);
    %     %
    %     r=cosmo_distatis(dsms,'return','distance','progress',false);
    %     cosmo_disp(r);
    %     %|| .samples
    %     %||   [     0         0         0         0         0         0
    %     %||     0.818      1.09      0.77     0.653      1.03     0.421
    %     %||     0.869       1.3      1.06      1.04     0.932      1.07
    %     %||       :         :         :         :         :         :
    %     %||      1.16     0.889      0.99     0.631      1.48     0.621
    %     %||     0.268     0.952     0.965     0.462     0.943      1.04
    %     %||         0         0         0         0         0         0 ]@16x6
    %     %|| .fa
    %     %||   .center_ids
    %     %||     [ 1         2         3         4         5         6 ]
    %     %||   .i
    %     %||     [ 1         2         3         1         2         3 ]
    %     %||   .j
    %     %||     [ 1         1         1         2         2         2 ]
    %     %||   .k
    %     %||     [ 1         1         1         1         1         1 ]
    %     %||   .nvoxels
    %     %||     [ 3         4         3         3         4         3 ]
    %     %||   .radius
    %     %||     [ 1         1         1         1         1         1 ]
    %     %||   .quality
    %     %||     [ 0.685     0.742     0.617     0.648     0.757     0.591 ]
    %     %||   .nchunks
    %     %||     [ 5         5         5         5         5         5 ]
    %     %|| .a
    %     %||   .fdim
    %     %||     .labels
    %     %||       { 'i'  'j'  'k' }
    %     %||     .values
    %     %||       { [ 1         2         3 ]  [ 1         2 ]  [ 1 ] }
    %     %||   .sdim
    %     %||     .labels
    %     %||       { 'targets1'  'targets2' }
    %     %||     .values
    %     %||       { [ 1    [ 1
    %     %||           2      2
    %     %||           3      3
    %     %||           4 ]    4 ] }
    %     %||   .vol
    %     %||     .mat
    %     %||       [ 2         0         0        -3
    %     %||         0         2         0        -3
    %     %||         0         0         2        -3
    %     %||         0         0         0         1 ]
    %     %||     .dim
    %     %||       [ 3         2         1 ]
    %     %||     .xform
    %     %||       'scanner_anat'
    %     %|| .sa
    %     %||   .targets1
    %     %||     [ 1
    %     %||       2
    %     %||       3
    %     %||       :
    %     %||       2
    %     %||       3
    %     %||       4 ]@16x1
    %     %||   .targets2
    %     %||     [ 1
    %     %||       1
    %     %||       1
    %     %||       :
    %     %||       4
    %     %||       4
    %     %||       4 ]@16x1
    %
    % Reference:
    %   - Abdi, H., Valentin, D., O?Toole, A. J., & Edelman, B. (2005).
    %     DISTATIS: The analysis of multiple distance matrices. In
    %     Proceedings of the IEEE Computer Society: International conference
    %     on computer vision and pattern recognition, San Diego, CA, USA
    %     (pp. 42?47).
    %
    % Notes:
    %   - DISTATIS tries to find an optimal compromise distance matrix across
    %     the different samples (participants)
    %   - Output can be reshape to matrix or array form using
    %     cosmo_unflatten(res,1)
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    cosmo_check_external('distatis');

    defaults.return = 'distance';
    defaults.split_by = 'chunks';
    defaults.shape = 'square';
    defaults.mask_output = [];
    defaults.progress = 100;
    defaults.feature_ids = [];
    defaults.autoscale = true;
    defaults.abs_correlation = false;
    defaults.weights = 'eig';

    opt = cosmo_structjoin(defaults, varargin);

    subject_cell = get_subject_data(ds, opt);
    nsubj = numel(subject_cell);

    [dsms, nclasses, dim_labels, dim_values] = get_dsms(subject_cell);

    feature_ids = get_feature_ids(size(dsms{1}, 3), opt);
    nfeatures = numel(feature_ids);

    quality = zeros(1, nfeatures);
    nobservations = zeros(1, nfeatures);

    prev_msg = '';
    clock_start = clock();
    show_progress = nfeatures > 1 && opt.progress;

    for k = 1:nfeatures
        feature_id = feature_ids(k);
        x = zeros(nclasses * nclasses, nsubj);

        for j = 1:nsubj
            dsm = dsms{j}(:, :, feature_id);
            x(:, j) = distance2crossproduct(dsm, opt.autoscale);
        end

        [x, subj_msk] = cosmo_remove_useless_data(x);
        nkeep = sum(subj_msk);

        % equivalent, but slower:
        % [e,v]=eigs(c,1);

        [ew, v] = get_weights(x, feature_id, nkeep, opt);

        % compute compromise
        compromise = x * ew;

        result = convert_compromise(compromise, opt);

        if feature_id == 1
            % allocate space
            samples = zeros(numel(result), nfeatures);
        end

        samples(:, k) = result;

        quality(:, k) = v / nkeep;
        nobservations(:, k) = nkeep;

        if show_progress && (k < 10 || ...
                             mod(k, opt.progress) == 0 || ...
                             k == nfeatures)
            status = sprintf('quality=%.3f%% (avg)', mean(quality(1:k)));
            prev_msg = cosmo_show_progress(clock_start, k / nfeatures, ...
                                           status, prev_msg);
        end
    end

    % set output in either triangular or square shape
    [res, i, j] = get_samples_in_shape(samples, nclasses, opt.shape);
    res = copy_fields(ds, res, {'fa', 'a'});

    % add attributes
    res.fa.quality = quality;
    res.fa.nchunks = nobservations;
    res.a.sdim = struct();
    res.a.sdim.labels = dim_labels;
    res.a.sdim.values = dim_values;

    res.sa.(dim_labels{1}) = i(:);
    res.sa.(dim_labels{2}) = j(:);

    cosmo_check_dataset(res);

function [res, i, j] = get_samples_in_shape(samples, nclasses, shape)
    res = struct();
    switch shape
        case 'triangle'
            [msk, i, j] = distance_matrix_mask(nclasses);
            res.samples = samples(msk(:), :);
        case 'square'
            res.samples = samples;
            [i, j] = find(ones(nclasses));
        otherwise
            error('unsupported direction %s', shape);
    end

function dst = copy_fields(src, dst, keys)
    for k = 1:numel(keys)
        key = keys{k};
        if isfield(src, key)
            dst.(key) = src.(key);
        end
    end

function feature_ids = get_feature_ids(nfeatures, opt)
    feature_ids = opt.feature_ids;
    if isempty(feature_ids)
        feature_ids = 1:nfeatures;
    end

function [ew, v] = get_weights(x, feature_id, nkeep, opt)
    switch opt.weights
        case 'eig'
            [ew, v] = eigen_weights(x, feature_id);

        case 'uniform'
            % all the same (allowing for comparison with 'eig')
            ew = ones(nkeep, 1) / nkeep;
            v = 0;

        otherwise
            error('illegal weight %s', opt.weights);
    end

function subject_cell = get_subject_data(ds, opt)
    if isstruct(ds)
        subject_cell = cosmo_split(ds, opt.split_by);
    else
        subject_cell = ds;
    end

    if numel(subject_cell) == 0
        error('empty input');
    end

function [ew, v] = eigen_weights(x, feature_id)

    c = cosmo_corr(x);

    negative_c = c < 0;

    if any(negative_c(:))

        [i, j] = find(negative_c);
        error(['feature %d has negative correlation between '...
               'sample %d and %d, which is not supported by '...
               'distatis. DISTATIS assumes that the similarity '...
               'data from all samples (typically: participants) '...
               'correlate positively. Because that is not the '...
               'case, you cannot use DISTATIS analysis on this '...
               'data. '], ...
              feature_id, i(1), j(1));
    end

    [v, e] = fast_eig1(c);

    if all(e < 0)
        e = -e;
    end

    assert(all(e > 0));
    assert(v > 0);

    % normalize first eigenvector
    ew = e / sum(e);

function result = convert_compromise(compromise, opt)
    switch opt.return
        case 'crossproduct'
            result = compromise;
        case 'distance'
            result = crossproduct2distance(compromise);
        otherwise
            error('illegal opt.return');
    end

function z = crossproduct2distance(x)
    n = sqrt(numel(x));
    e = ones(n, 1);
    d = x(1:(n + 1):end);
    dd = d * e';
    ddt = dd';
    y = dd(:) + ddt(:) - 2 * x;
    z = ensure_distance_vector(y);

function assert_symmetric(x, tolerance)
    if nargin < 2
        tolerance = 1e-8;
    end

    % assert x is a square matrix
    sz = size(x);
    assert(isequal(sz, sz([2 1])));

    xx = x' - x;

    msk = xx > tolerance;
    if any(msk)
        [i, j] = find(msk, 1);
        error('not symmetric: x(%d,%d)=%d ~= %d=x(%d,%d)', ...
              i, j, x(i, j), x(j, i), j, i);
    end

function z_vec = distance2crossproduct(x, autoscale)

    n = size(x, 1);
    e = ones(n, 1);
    m = e * (1 / n);
    ee = eye(n) - e * m';
    y = -.5 * ee * (x + x') * ee';
    if autoscale
        z = (1 / fast_eig1(y)) * y;
    else
        z = y;
    end
    assert_symmetric(z);
    % equivalent, but slower:
    % z=(1/eigs(y,1))*y(:);

    z_vec = z(:);

function [lambda, pivot] = fast_eig1(x)
    % returns the first eigenvalue in lambda, and the corresponding
    % eigenvector in pivot
    if cosmo_wtf('is_matlab')
        [pivot, lambda] = eigs(x, 1);
    else
        % There seems a bug in Octave for 'eigs',
        % so use 'eig' instead.
        % http://savannah.gnu.org/bugs/?44004
        [e, v] = eig(x);
        diag_v = diag(v);

        % find largest eigenvalue and eigenvector
        [lambda, i] = max(diag_v);
        pivot = e(:, i);
    end

    % The code below is disabled because under certain circumstances
    % it would return a near-zero eigenvalue if indeed one eigenvalue (but
    % not the largest one) is zero.
    % % compute first (largest) eigenvalue and corresponding eigenvector
    % % using power iteration method; benchmarking suggests this can be up to
    % % five times as fast as using eigs(x,1)
    % n=size(x,1);
    % pivot=ones(n,1);
    % tolerance=1e-8;
    % max_iter=1000;
    %
    % old_lambda=NaN;
    % for k=1:max_iter
    %     z=x*pivot;
    %     pivot=z / norm(z);
    %
    %     lambda=pivot'*z;
    %     if abs(lambda-old_lambda)/lambda<tolerance
    %         z=x*pivot;
    %         pivot=z / sqrt(sum(z.^2));
    %
    %         lambda=pivot'*z;
    %         return
    %     end
    %     old_lambda=lambda;
    % end
    %
    % % matlab fallback
    % [pivot,lambda]=eigs(x,1);

function y = ensure_distance_vector(x)
    tolerance = 1e-8;

    n = sqrt(numel(x));
    xsq = reshape(x, n, n);

    dx = diag(xsq);
    assert(all(dx < tolerance));

    xsq = xsq - diag(dx);

    delta = xsq - xsq';
    assert(all(delta(:) < tolerance));

    xsq = .5 * (xsq + xsq');
    y = xsq(:);

function [dsms, nclasses, dim_labels, dim_values] = get_dsms(data_cell)
    nsubj = numel(data_cell);

    % allocate
    dsms = cell(nsubj, 1);
    for k = 1:numel(data_cell)
        data = data_cell{k};

        % get data
        [dsm, dim_labels, dim_values, is_ds] = get_dsm(data);

        % store data
        dsms{k} = dsm;

        if k == 1
            nclasses = size(dsm, 1);
            first_dim_labels = dim_labels;
            first_dim_values = dim_values;

            data_first = data;
        else

            if ~isequal(first_dim_labels, dim_labels)
                error('dim label mismatch between subject 1 and %d', k);
            end
            if ~isequal(first_dim_values, dim_values)
                error('dim label mismatch between subject 1 and %d', k);
            end

            % check for compatibility over subjects, raises an error if not
            % kosher
            if is_ds
                cosmo_stack({cosmo_slice(data, 1), ...
                             cosmo_slice(data_first, 1)}, 1, 'unique');
            end
        end
    end

function [msk, i, j] = distance_matrix_mask(nclasses)
    msk = triu(repmat(1:nclasses, nclasses, 1), 1)' > 0;
    [i, j] = find(msk);

function [dsm, dim_labels, dim_values, is_ds] = get_dsm(data)
    is_ds = isstruct(data);
    if is_ds
        [dsm, dim_labels, dim_values] = cosmo_unflatten(data, 1);
    elseif isnumeric(data)
        sz = size(data);
        if numel(sz) ~= 2
            error('only vectorized distance matrices are supported');
        end
        [n, nfeatures] = size(data);

        side = (1 + sqrt(1 + 8 * n)) / 2; % so that side*(side-1)/2==n
        if ~isequal(side, round(side))
            error(['size %d of input vector is not correct for '...
                   'the number of elements below the diagonal of a '...
                   'square (distance) matrix'], n);
        end

        [msk, i, j] = distance_matrix_mask(side);
        dsm = zeros([side, side, nfeatures]);

        assert(numel(i) == n);
        for pos = 1:n
            dsm(i(pos), j(pos), :) = data(pos, :);
        end

        sq1 = cosmo_squareform(data(:, 1));
        dsm1 = dsm(:, :, 1);
        assert(isequal(sq1, dsm1 + dsm1'));

        dim_labels = {'targets1', 'targets2'};
        dim_values = {(1:side)', (1:side)'};
    else
        error('illegal input: expect dataset struct, or cell with arrays');
    end