cosmo flatten skl

function ds=cosmo_flatten(arr, dim_labels, dim_values, dim, varargin)
% flattens an arbitrary array to a dataset structure
%
% ds=cosmo_flatten(arr, dim_labels, dim_values, dim[, ...])
%
% Inputs:
%   arr                S_1 x ... x S_K x Q input array if (dim==1), or
%                      P x S_1 x ... x S_K input array if (dim==2)
%   dim_labels         1xK cell containing labels for each of the K
%                      dimensions that are flattened, i.e. all dimensions
%                      of arr but the last one (dim==1) or but the first
%                      one (dim==2).
%   dim_values         1xK cell; the J-th element (J in 1:K) contains the
%                      S_J values corresponding to the J-th label. Vector
%                      values can be provided as row or column vectors.
%   dim                dimension along which to flatten, either 1 (samples)
%                      or 2 (features; default). If omitted or empty the
%                      default is used.
%   'matrix_labels',m  Allow labels in the cell string m to be matrices
%                      rather than vectors. Currently the only use case is
%                      the 'pos' attribute for MEEG source space data.
%
% Output:
%   ds                 dataset structure, with fields:
%      .samples        PxQ data for P samples and Q features, where
%                      P=prod(S_*) if dim==1 and Q=prod(S_*) if dim==2
%      .a.Xdim.labels  cell with the values in dim_labels (X=='s' if
%                      dim==1, and 'f' if dim==2).
%      .a.Xdim.values  cell with the values in dim_values; the M-th
%                      element has S_M values. Vector values are stored
%                      as column vectors if dim==1 and as row vectors
%                      if dim==2.
%      .Xa.(label)     for each label in a.Xdim.labels it contains the
%                      sub-indices for the K dimensions. It is ensured
%                      that for every dimension J in 1:K, all values in
%                      ds.Xa.(a.Xdim.labels{J}) are in the range 1:S_J.
%
% Examples:
%     % typical usage: flatten features in 2x3x5 array, 1 sample
%     data=reshape(1:30, [1 2,3,5]);
%     ds=cosmo_flatten(data,{'i','j','k'},{1:2,1:3,{'a','b','c','d','e'}});
%     cosmo_disp(ds)
%     %|| .samples
%     %||   [ 1         2         3  ...  28        29        30 ]@1x30
%     %|| .fa
%     %||   .i
%     %||     [ 1 2 1  ...  2 1 2 ]@1x30
%     %||   .j
%     %||     [ 1 1 2  ...  2 3 3 ]@1x30
%     %||   .k
%     %||     [ 1 1 1  ...  5 5 5 ]@1x30
%     %|| .a
%     %||   .fdim
%     %||     .labels
%     %||       { 'i'  'j'  'k' }
%     %||     .values
%     %||       { [ 1 2 ]  [ 1 2 3 ]  { 'a'  'b'  'c'  'd'  'e' } }
%
%     % flatten samples in 1x1x2x3 array, 5 features
%     data=reshape(1:30, [1,1,2,3,5]);
%     ds=cosmo_flatten(data,{'i','j','k','m'},{1,'a',(1:2)',(1:3)'},1);
%     cosmo_disp(ds);
%     %|| .samples
%     %||   [ 1         7        13        19        25
%     %||     2         8        14        20        26
%     %||     3         9        15        21        27
%     %||     4        10        16        22        28
%     %||     5        11        17        23        29
%     %||     6        12        18        24        30 ]
%     %|| .sa
%     %||   .i
%     %||     [ 1
%     %||       1
%     %||       1
%     %||       1
%     %||       1
%     %||       1 ]
%     %||   .j
%     %||     [ 1
%     %||       1
%     %||       1
%     %||       1
%     %||       1
%     %||       1 ]
%     %||   .k
%     %||     [ 1
%     %||       2
%     %||       1
%     %||       2
%     %||       1
%     %||       2 ]
%     %||   .m
%     %||     [ 1
%     %||       1
%     %||       2
%     %||       2
%     %||       3
%     %||       3 ]
%     %|| .a
%     %||   .sdim
%     %||     .labels
%     %||       { 'i'  'j'  'k'  'm' }
%     %||     .values
%     %||       { [ 1 ]  'a'  [ 1    [ 1
%     %||                       2 ]    2
%     %||                              3 ] }
%
%
% Notes:
%   - Intended use is for flattening fMRI or MEEG datasets
%   - This function is the inverse of cosmo_unflatten.
%
% See also: cosmo_unflatten, cosmo_fmri_dataset, cosmo_meeg_dataset
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    defaults.matrix_labels=cell(0);
    opt=cosmo_structjoin(defaults,varargin{:});

    % an omitted or empty dim selects the default (flatten features);
    % the empty case allows options to be passed without choosing a dim
    if nargin<4 || isempty(dim), dim=2; end

    switch dim
        case 1
            do_transpose=true;
            attr_name='sa';
            dim_name='sdim';
        case 2
            do_transpose=false;
            attr_name='fa';
            dim_name='fdim';
        otherwise
            error('illegal dim: must be 1 or 2');
    end

    if do_transpose
        % switch samples and features, so that the array to be flattened
        % has its non-flattened dimension first (as in the dim==2 case)
        ndim=numel(dim_labels);
        nfeatures=size(arr,ndim+1);
        if nfeatures==1
            % the trailing singleton dimension is not reported by size(),
            % so insert a leading singleton dimension explicitly
            arr=reshape(arr,[1 size(arr)]);
        else
            % move the last dimension to the front
            arr=shiftdim(arr,ndim);
        end
        dim_values=cellfun(@transpose,dim_values,'UniformOutput',false);
    end

    [samples,dim_values,attr]=flatten_features(arr, dim_labels, ...
                                            dim_values, opt);

    if do_transpose
        % use the non-conjugate transpose, so that complex-valued data
        % is not changed into its complex conjugate
        samples=samples.';
        attr=transpose_attr(attr);
        dim_values=cellfun(@transpose,dim_values,'UniformOutput',false);
    end

    ds=struct();
    ds.samples=samples;
    ds.(attr_name)=attr;
    ds.a.(dim_name).labels=dim_labels;
    ds.a.(dim_name).values=dim_values;

function attr=transpose_attr(attr)
    % Transposes the value of every field in the struct attr
    %
    % Input:
    %   attr     struct
    %
    % Output:
    %   attr     struct with the same fields as the input, each holding
    %            the transposed (') value of the corresponding input field

    % iterate over the field names directly; each iteration yields a
    % 1x1 cell containing a single field name
    for name_cell=fieldnames(attr)'
        name=name_cell{1};
        attr.(name)=ctranspose(attr.(name));
    end

function [samples,dim_values,attr]=flatten_features(arr, dim_labels, ...
                                                        dim_values, opt)
    % Flattens all dimensions of arr but the first one
    %
    % Inputs:
    %   arr          array of which all dimensions after the first one
    %                are flattened
    %   dim_labels   cell with a label for each flattened dimension
    %   dim_values   cell with the values for each flattened dimension
    %   opt          options passed on to get_dim_sizes
    %
    % Outputs:
    %   samples      matrix with size(arr,1) rows and one column for
    %                each combination of sub-indices
    %   dim_values   the dimension values as returned by get_dim_sizes
    %   attr         struct with, for each label, a row vector with the
    %                sub-index in that dimension for every column of
    %                samples

    nlabels=numel(dim_labels);
    nvalue_sets=numel(dim_values);
    arr_ndim=ndims(arr);

    % labels and values must match up, and the array cannot have more
    % dimensions than there are labels (plus one for the first dimension)
    if nlabels ~= nvalue_sets
        error('expected %d dimensions, found %d',nlabels,nvalue_sets);
    end

    if arr_ndim>(nlabels+1)
        error('Array has %d dimensions, expected <= %d',...
                                        arr_ndim,nlabels+1);
    end

    % size of each flattened dimension; this is based on the dimension
    % values so that trailing singleton dimensions in arr are supported
    [dim_sizes,dim_values]=get_dim_sizes(arr,dim_labels,dim_values,opt);

    attr=struct();

    for j=1:nlabels
        count=size(dim_values{j},2);

        % sub-indices 1:count, oriented along the j-th dimension
        idx=1:count;
        if nlabels>1
            % (with a single dimension the row vector is used as is)
            shape=ones(1,nlabels);
            shape(j)=count;
            idx=reshape(idx,shape);
        end

        % tile along all other dimensions; the j-th dimension already
        % has its full size and is therefore not repeated
        tiles=dim_sizes;
        tiles(j)=1;
        tiled=repmat(idx, tiles(:)');

        % linearize to a row vector
        attr.(dim_labels{j})=reshape(tiled, 1, []);
    end

    % collapse all flattened dimensions into the second dimension
    samples=reshape(arr, size(arr,1), prod(dim_sizes));


function [dim_sizes, dim_values]=get_dim_sizes(arr,dim_labels,dim_values,opt)
    % Determines and verifies the number of values in each dimension
    %
    % Inputs:
    %   arr          array; for each J, its (J+1)-th dimension must have
    %                as many elements as there are values in dim_values{J}
    %   dim_labels   cell with a label for each dimension
    %   dim_values   cell with the values for each dimension
    %   opt          struct with field .matrix_labels, containing the
    %                labels for which the values can be a matrix instead
    %                of a vector
    %
    % Outputs:
    %   dim_sizes    1xK vector with the number of values in each of the
    %                K=numel(dim_values) dimensions
    %   dim_values   same as the input, except that values for labels not
    %                in opt.matrix_labels are converted to row vectors
    %
    % An error is raised if values for a label not in opt.matrix_labels
    % are not a vector, or if the number of values in a dimension does
    % not match the size of arr.

    ndim=numel(dim_values);
    dim_sizes=zeros(1,ndim);

    for dim=1:ndim
        dim_label=dim_labels{dim};
        dim_value=dim_values{dim};

        if cosmo_match({dim_label},opt.matrix_labels)
            % matrix values: each column is a single value
            dim_size=size(dim_value,2);
        else
            if ~isvector(dim_value)
                error(['Label ''%s'' (dimension %d) must be a vector, '...
                        'because it was not specified as a matrix '...
                        'dimension in the ''matrix_labels'' option'],...
                        dim_label, dim);
            end
            dim_size=numel(dim_value);
            dim_values{dim}=dim_value(:)'; % make it a row vector
        end

        % the first dimension of arr is not flattened, hence the offset
        expected_size=size(arr,dim+1);
        if dim_size ~= expected_size
            error(['Label ''%s'' (dimension %d) has %d values, ',...
                        'expected %d based on the array input'],...
                    dim_label, dim, dim_size, expected_size);
        end

        dim_sizes(dim)=dim_size;
    end