cosmo unflatten

function [arr, dim_labels, dim_values] = cosmo_unflatten(ds, dim, varargin)
    % unflattens a dataset from 2 to (1+K) dimensions.
    %
    % [arr, dim_labels, dim_values]=cosmo_unflatten(ds[, dim, ...])
    %
    % Inputs:
    %   ds                 dataset structure, with fields:
    %      .samples        PxQ for P samples and Q features.
    %      .a.Xdim.labels  1xK cell with string labels for each dimension,
    %                      with X='s' for samples (dim=1) or X='f' for features
    %                      (dim=2).
    %      .a.Xdim.values  1xK cell, with S_J values (J in 1:K) corresponding
    %                      to the labels in each of the K dimensions.
    %      .Xa.(label)     for each label in a.Xdim.labels it contains the
    %                      sub-indices for the K dimensions. It is required
    %                      that for every dimension J in 1:K, all values in
    %                      ds.Xa.(a.Xdim.labels{J}) are in the range 1:S_J, and
    %                      that every combination across labels is unique.
    %   dim                dimension to be unflattened, either 1 (for samples)
    %                      or 2 (for features; default). Use [] to select
    %                      the default when also passing options.
    %   'set_missing_to',s value to set missing values to (default: 0)
    %   'matrix_labels',m  Allow labels in the cell string m to be matrices
    %                      rather than vectors. Currently the only use case is
    %                      the 'pos' attribute for MEEG source space data.
    %
    % Returns:
    %   arr                S_1 x ... x S_K x Q array if (dim==1), or
    %                      P x S_1 x ... x S_K array if (dim==2). Here
    %                      P<=prod(S_*) if dim==1 and Q<=prod(S_*) if dim==2;
    %                      locations in arr without a corresponding sample
    %                      (dim==1) or feature (dim==2) are set to the
    %                      value of the 'set_missing_to' option.
    %   dim_labels         the value of .a.Xdim.labels
    %   dim_values         the value of .a.Xdim.values
    %
    % Example:
    %     % ds is an FMRI dataset with 6 samples, volumes are 3 x 2 x 5 voxels
    %     ds=cosmo_synthetic_dataset('size','normal','type','fmri');
    %     size(ds.samples)
    %     %|| [ 6 30 ]
    %     cosmo_disp(ds.a.fdim)
    %     %|| .labels
    %     %||   { 'i'  'j'  'k' }
    %     %|| .values
    %     %||   { [ 1 2 3 ]  [ 1 2 ]  [ 1 2 3 4 5 ] }
    %     %
    %     % unflatten the dataset
    %     [unfl,labels,values]=cosmo_unflatten(ds);
    %     %
    %     % the unflattened dataset is of size 6 x 3 x 2 x 5
    %     size(unfl)
    %     %|| [ 6 3 2 5 ]
    %     cosmo_disp(labels)
    %     %|| { 'i'  'j'  'k' }
    %     cosmo_disp(values)
    %     %|| { [ 1 2 3 ]  [ 1 2 ]  [ 1 2 3 4 5 ] }
    %
    %     % ds is a small dataset with 2 classes
    %     ds=cosmo_synthetic_dataset();
    %     %
    %     % compute all (2x2) split-half correlation values
    %     res=cosmo_correlation_measure(ds,'output','raw',...
    %                                       'post_corr_func',[]);
    %     cosmo_disp(res)
    %     %|| .samples
    %     %||   [  0.363
    %     %||     -0.404
    %     %||     -0.447
    %     %||      0.606 ]
    %     %|| .sa
    %     %||   .half1
    %     %||     [ 1
    %     %||       2
    %     %||       1
    %     %||       2 ]
    %     %||   .half2
    %     %||     [ 1
    %     %||       1
    %     %||       2
    %     %||       2 ]
    %     %|| .a
    %     %||   .sdim
    %     %||     .labels
    %     %||       { 'half1'  'half2' }
    %     %||     .values
    %     %||       { [ 1    [ 1
    %     %||           2 ]    2 ] }
    %     %
    %     % reshape the correlations into a square matrix
    %     [unfl,labels,values]=cosmo_unflatten(res,1);
    %     %
    %     % yields a 2x2x1 matrix (matlab omits the last, singleton dimension)
    %     cosmo_disp(unfl)
    %     %|| [  0.363    -0.447
    %     %||   -0.404     0.606 ]
    %     %
    %     cosmo_disp(labels)
    %     %|| { 'half1'  'half2' }
    %     %
    %     cosmo_disp(values)
    %     %|| { [ 1    [ 1
    %     %||     2 ]    2 ] }
    %
    %
    % Notes:
    %   - A typical use case is mapping an fMRI or MEEG dataset struct
    %     back to a 3D or 4D array.
    %   - This function is the inverse of cosmo_flatten.
    %
    % See also: cosmo_flatten, cosmo_map2fmri, cosmo_map2meeg
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % default: unflatten the feature dimension
    if nargin < 2 || isempty(dim)
        dim = 2;
    end

    if ~(isnumeric(dim) && isscalar(dim))
        error('second argument must be numeric');
    end

    cosmo_check_dataset(ds);

    defaults = struct();
    defaults.set_missing_to = 0;
    defaults.matrix_labels = cell(0);
    opt = cosmo_structjoin(defaults, varargin);

    % select the dimension attributes (.a.sdim / .a.fdim) and the matching
    % sample or feature attributes (.sa / .fa)
    switch dim
        case 1
            cosmo_isfield(ds, {'a.sdim', 'samples', 'sa'}, true);
            do_transpose = true;
            a_dim = ds.a.sdim;
            attr = ds.sa;

        case 2
            cosmo_isfield(ds, {'a.fdim', 'samples', 'fa'}, true);
            do_transpose = false;
            a_dim = ds.a.fdim;
            attr = ds.fa;

        otherwise
            error('dim must be 1 or 2');
    end

    % the helper always unflattens the second dimension; for dim==1 the
    % data (and the dimension values) are transposed first so that the
    % sample dimension takes the role of the feature dimension
    samples = ds.samples;
    if do_transpose
        samples = samples';
        a_dim.values = cellfun(@transpose, a_dim.values, ...
                               'UniformOutput', false);
    end

    [arr, dim_labels, dim_values] = unflatten_features(samples, ...
                                                       a_dim, attr, opt);

    if do_transpose
        % move the leading dimension to the end, giving
        % S_1 x ... x S_K x Q. An explicit permutation over K+1
        % dimensions is used rather than shiftdim(arr, 1): trailing
        % singleton dimensions are dropped from arr, so if S_K==1 then
        % ndims(arr)<K+1 and shiftdim would rotate too few dimensions,
        % putting Q in the position of S_K.
        ndim = numel(dim_labels);
        arr = permute(arr, [2:(ndim + 1), 1]);
        dim_values = cellfun(@transpose, dim_values, 'UniformOutput', false);
    end

function [arr, dim_labels, dim_values] = unflatten_features(samples, ...
                                                            a_dim, attr, opt)
    % Unflattens the second dimension of a samples matrix.
    %
    % Inputs:
    %   samples       PxQ matrix
    %   a_dim         struct with fields .labels and .values (1xK cells)
    %   attr          struct with, for each label in a_dim.labels, a field
    %                 with Q sub-indices for that dimension
    %   opt           struct with fields .set_missing_to and .matrix_labels
    %
    % Returns:
    %   arr           P x S_1 x ... x S_K array (double), where S_J is the
    %                 size of the J-th dimension as computed by
    %                 get_dim_sizes; locations not indexed by any column of
    %                 samples are set to opt.set_missing_to
    %   dim_labels    a_dim.labels
    %   dim_values    a_dim.values as returned by get_dim_sizes
    %
    % An error is raised if an index is not a positive integer, exceeds
    % the size of its dimension, or if two columns map to the same
    % location.

    nsamples = size(samples, 1);
    dim_labels = a_dim.labels;
    dim_values = a_dim.values;

    % number of feature dimensions
    ndim = numel(dim_labels);

    % get sub indices for each feature dimension
    sub_indices = cellfun(@(x)attr.(x), dim_labels, 'UniformOutput', false);

    % get dimension values
    [dim_sizes, dim_values] = get_dim_sizes(dim_values, dim_labels, opt);

    % validate the indices and find the largest one in each dimension.
    % An explicit loop is used (rather than cellfun(@max,...)) so that
    % attributes without any elements are supported; these count as
    % having a maximum index of zero.
    max_indices = zeros(1, ndim);
    for j = 1:ndim
        idx = sub_indices{j};
        idx = double(idx(:));
        if isempty(idx)
            continue
        end

        % the comparison with round() also rejects NaN
        if any(idx < 1 | idx ~= round(idx))
            error(['attribute with label %s must contain positive '...
                   'integer indices'], dim_labels{j});
        end

        max_indices(j) = max(idx);
    end

    too_small_dim = find(max_indices(:) > dim_sizes(:), 1);
    if ~isempty(too_small_dim)
        error(['dimension with label %s has %d dimension labels, '...
               'but attribute indexes up to %d'], ...
              dim_labels{too_small_dim}, dim_sizes(too_small_dim), ...
              max_indices(too_small_dim));
    end

    % convert sub indices to linear indices
    if ndim == 1
        lin_indices = sub_indices{1};
    else
        lin_indices = sub2ind(dim_sizes, sub_indices{:});
    end

    % every location can be used at most once; after sorting, duplicate
    % linear indices are adjacent. The smallest duplicated index is
    % reported, with the first two positions where it occurs.
    sorted_lin_indices = sort(lin_indices(:));
    duplicate_pos = find(diff(sorted_lin_indices) == 0, 1);

    if ~isempty(duplicate_pos)
        duplicate = sorted_lin_indices(duplicate_pos);
        two_duplicate_pos = find(lin_indices == duplicate, 2);

        error('Duplicate features at #%d and #%d', ...
              two_duplicate_pos(1), two_duplicate_pos(2));
    end

    % allocate space for all samples at once, with each row holding one
    % sample in linearized 'ndim'-space, filled with the missing value
    nlocations = prod(dim_sizes);
    missing_row = zeros(1, nlocations);
    missing_row(:) = opt.set_missing_to;
    arr = repmat(missing_row, nsamples, 1);

    % assign all samples to their proper location
    arr(:, lin_indices) = samples;

    % the first dimension is used for the samples, the remaining ones
    % for the unflattened dimensions. Linear indices in 'ndim'-space are
    % column indices in arr, so this reshape puts every value at its
    % sub-indices.
    arr = reshape(arr, [nsamples, dim_sizes]);

function [dim_sizes, dim_values] = get_dim_sizes(dim_values, dim_labels, opt)
    % Computes the number of elements in each dimension.
    %
    % Inputs:
    %   dim_values    cell with the values for each dimension
    %   dim_labels    cell with the label for each dimension; must have
    %                 the same number of elements as dim_values
    %   opt           struct with field .matrix_labels; labels matching it
    %                 (according to cosmo_match) may have matrix values
    %
    % Returns:
    %   dim_sizes     1xK vector; for a matrix label the number of columns
    %                 of the values, otherwise the number of elements
    %   dim_values    the input values, with values of non-matrix labels
    %                 converted to row vectors
    %
    % An error is raised if the number of values and labels differ, or if
    % the values for a non-matrix label are not a vector.

    ndim = numel(dim_labels);
    if numel(dim_values) ~= ndim
        error(['size mismatch between number of dimension values (%d) '...
               'and dimension labels (%d)'], ...
              numel(dim_values), ndim);
    end

    % number of elements in each dimension
    dim_sizes = zeros(1, ndim);

    % go over dimensions
    for j = 1:ndim
        dim_label = dim_labels{j};
        dim_value = dim_values{j};

        if cosmo_match({dim_label}, opt.matrix_labels)
            % matrix values: each column is one element of the dimension
            dim_size = size(dim_value, 2);
        else
            if ~isvector(dim_value)
                error(['Label ''%s'' (dimension %d) must be a vector, '...
                       'because it was not specified as a matrix '...
                       'dimension in the ''matrix_labels'' option'], ...
                      dim_label, j);
            end
            dim_size = numel(dim_value);
            dim_values{j} = dim_value(:)'; % make it a row vector
        end

        dim_sizes(j) = dim_size;
    end