cosmo tiedrank skl

function ranks=cosmo_tiedrank(data, dim)
% Compute ranks for the input along the specified dimension
%
% ranks=cosmo_tiedrank(data[, dim])
%
% Inputs:
%   data                        numeric N-dimensional array
%   dim                         optional dimension along which the ranks
%                               are computed (default: 1)
%
% Output:
%   ranks                       numeric N-dimensional array with the same
%                               size as the input containing the rank of
%                               each vector along the dim-th dimension.
%                               Equal values have the same rank, which is
%                               the average of the rank the values would
%                               have if they differed by a minimal amount.
%                               NaN values in the input result in NaN
%                               values in the output at the corresponding
%                               locations.
%                               If dim is greater than the number of
%                               dimensions in data, then all values in rank
%                               are one (or NaN if the corresponding value
%                               in data is NaN).
%                               If data is empty, then ranks is an empty
%                               array with the same size as data.
%
% Examples:
%     cosmo_tiedrank([1 2 2],2)
%     %|| [ 1 2.5 2.5]
%
%     cosmo_tiedrank([NaN 2 2;3 NaN 4],1)
%     %|| [ NaN     1     1;
%     %||     1   NaN     2];
%
%     cosmo_tiedrank([NaN 2 2;3 NaN 4],2)
%     %|| [ NaN   1.5   1.5;
%     %||     1   NaN     2];
%
%     cosmo_tiedrank([2 4 3 3 3 3 5 5 5],2)
%     %|| [ 1.0 6.0 3.5 3.5 3.5 3.5 8.0 8.0 8.0 ]
%
% Notes:
% - Unlike the Matlab builtin function 'tiedrank' (part of the statistics
%   toolbox), the meaning of the second argument is the dimension along
%   which the ranks are computed.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    if nargin<2
        dim=1;
    end

    check_inputs(data, dim);
    orig_size=size(data);

    if isempty(data)
        % nothing to rank. Returning early also avoids the call to
        % reshape(...,count_along_dim,[]) below with count_along_dim==0,
        % for which the size of the remaining dimension cannot be inferred
        ranks=zeros(orig_size);
        return;
    end

    if numel(orig_size)<dim
        % if the input data does not have enough dimensions, then the
        % output consists of an array with only ones (or NaNs, if present)
        ranks=singleton_ranks(data);
        return;
    end

    % sort once along the requested dimension; the helper functions work
    % on the sorted values and the corresponding sorting indices
    [values,idx]=sort(data,dim);

    % a 2D input that is singleton in the dimension other than dim can be
    % processed directly as a single vector
    data_is_vector=numel(orig_size)<=2 && orig_size(3-dim)==1;
    if data_is_vector
        ranks=vector_tied_rank(values(:), idx(:));
        if orig_size(1)==1
            % transpose to turn it back into a row vector
            ranks=ranks';
        end
        return
    end


    % make the dim-th dimension the first dimension
    values_sh=shiftdim(values,dim-1);
    idx_sh=shiftdim(idx,dim-1);
    sh_size=size(values_sh);

    count_along_dim=size(values_sh,1);

    % reshape into a matrix, so that each column is one vector to rank
    values_mat=reshape(values_sh,count_along_dim,[]);
    idx_mat=reshape(idx_sh,count_along_dim,[]);

    % space for output
    ranks_mat=zeros(size(idx_mat));

    % compute for each column vector
    n_col=size(ranks_mat,2);
    for k=1:n_col
        ranks_mat(:,k)=vector_tied_rank(values_mat(:,k),idx_mat(:,k));
    end

    % put back in shape after shiftdim
    ranks_sh=reshape(ranks_mat,sh_size);

    % undo shiftdim: shifting by the remaining number of dimensions
    % completes the circular shift; the final reshape restores any
    % singleton dimensions that were dropped along the way
    unshift_count=numel(orig_size)-dim+1;
    ranks=reshape(shiftdim(ranks_sh,unshift_count),orig_size);



function ranks=vector_tied_rank(sorted_values, sort_idx)
% Compute tied ranks for a single vector
%
% sorted_values and sort_idx are the two outputs from 'sort' applied to
% the original vector; it is assumed that sorted_values is a vector.
% The output has the same size as sort_idx and contains, for each element
% of the original (unsorted) vector, its rank; elements with equal values
% share the average of their positions, and NaN elements have rank NaN.
    n_total=numel(sorted_values);

    % NaNs are expected at the end of the sorted values, so that the
    % leading n_valid elements are the ones that receive a rank
    n_valid=n_total-sum(isnan(sorted_values));

    % start with NaN everywhere, then assign ranks as if there are no ties
    ranks=NaN+sort_idx;
    ranks(sort_idx(1:n_valid))=1:n_valid;

    % positions in the sorted vector where an element equals its successor
    same_as_next=sorted_values(1:(end-1))==sorted_values(2:end);
    pair_pos=find(same_as_next);
    n_pairs=numel(pair_pos);

    p=1;
    while p<=n_pairs
        first=pair_pos(p);
        last=first+1;

        % extend the run of equal values as far as possible; every
        % extension corresponds to one more entry in pair_pos, which is
        % skipped by advancing p
        while last<n_total ...
                && sorted_values(last)==sorted_values(last+1)
            last=last+1;
            p=p+1;
        end

        % all elements in the run get the mean of their positions
        ranks(sort_idx(first:last))=(first+last)/2;

        p=p+1;
    end


function ranks=singleton_ranks(data)
    % Ranks for the case that every vector consists of a single element:
    % the output has the same size as data, with NaN wherever data is NaN
    % and 1 everywhere else
    nan_msk=isnan(data);

    ranks=ones(size(data));
    ranks(nan_msk)=NaN;


function check_inputs(data, dim)
    % Verify the input arguments, raising an error if data is not numeric
    % or if dim is not a numeric scalar with a positive integer value
    if ~isnumeric(data)
        error('First input must be numeric')
    end

    % short-circuit evaluation ensures that the value of dim is only
    % inspected after it is known to be a numeric scalar
    dim_is_valid=isnumeric(dim) ...
                    && isscalar(dim) ...
                    && round(dim)==dim ...
                    && dim>0;

    if ~dim_is_valid
        error('Second argument must be numeric integer');
    end