cosmo dim insert skl

function ds=cosmo_dim_insert(ds,dim,index,labels,values,attr,varargin)
% insert a dataset dimension
%
% ds_result=cosmo_dim_insert(ds,dim,index,labels,values,attr,...)
%
% Inputs:
%   ds                  dataset struct
%   dim                 dimension along which dimensions must be inserted,
%                       1=samples, 2=features
%   index               position at which dimension must be inserted,
%                       in .a.sdim (if dim==1) or .a.fdim (if dim==2).
%                       Values 1..n+1 (with n the current number of
%                       dimensions) insert before that position; 0 appends
%                       at the end; negative values -n..-1 count from the
%                       end (-1 inserts before the last dimension)
%   labels              dimension labels (cell with strings)
%   values              dimension values (cell with as many elements as
%                       labels)
%   attr                cell with values for .sa or .fa, or a struct with
%                       the fields that are in labels
%   'matrix_labels',m   (optional) any label for which the corresponding
%                       value is a matrix must be an element of the
%                       cellstring m. Currently this applies to the 'pos'
%                       field in MEEG source data
%   'check_dataset',c   (optional, default true) if true, the result is
%                       passed to cosmo_check_dataset before returning
%
% Output:
%   ds_result           dataset struct with the labels and values inserted
%                       in .a.{fdim,sdim}, and the attributes set in
%                       .{fa,sa}.
%
% Example:
%     % generate tiny fmri dataset
%     ds=cosmo_synthetic_dataset();
%     %
%     % remove first two feature dimensions ('i' and 'j')
%     dim_labels=ds.a.fdim.labels(1:2);
%     dim_values=ds.a.fdim.values(1:2);
%     dsr=cosmo_dim_remove(ds,dim_labels);
%     %
%     % add them back in
%     ds_humpty=cosmo_dim_insert(dsr,2,1,dim_labels,dim_values,...
%                                             {ds.fa.i,ds.fa.j});
%     %
%     % the output is the same as the original dataset
%     isequal(ds,ds_humpty)
%
% Notes:
%   - this is a utility function, mostly intended for use by other
%     functions
%   - this function does not check for duplicate dimensions
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #


    defaults.matrix_labels=cell(0);
    defaults.check_dataset=true;
    opt=cosmo_structjoin(defaults,varargin{:});

    % dim selects between sample ('s') and feature ('f') prefixes, so
    % anything other than 1 or 2 cannot be handled
    if ~(isequal(dim,1) || isequal(dim,2))
        error('dim must be 1 (samples) or 2 (features)');
    end

    % validate labels up front, because the helpers below index into it
    if ~iscellstr(labels)
        error('labels must be a cell with strings');
    end

    prefixes='sf';
    prefix=prefixes(dim);
    attr_name=[prefix 'a'];
    dim_name=[prefix 'dim'];

    ds=ensure_has_xa_xdim(ds,dim,attr_name,dim_name);

    % get values in proper size
    attr_values=get_attr_values(labels,attr);
    dim_values=get_dim_values(labels,values,dim,opt);

    % insert labels and values at the same position so they stay aligned
    ds.a.(dim_name).labels=insert_elements(ds.a.(dim_name).labels, ...
                                                index, labels, dim);
    ds.a.(dim_name).values=insert_elements(ds.a.(dim_name).values, ...
                                                index, dim_values, dim);

    % set (or overwrite) the attribute for each inserted label
    for j=1:numel(labels)
        label=labels{j};
        ds.(attr_name).(label)=attr_values{j};
    end

    if opt.check_dataset
        cosmo_check_dataset(ds);
    end

function ds=ensure_has_xa_xdim(ds,dim,attr_name,dim_name)
% Make sure that ds has the fields .(attr_name) and .a.(dim_name)
%
% A missing .(attr_name) is set to an empty struct. A missing
% .a.(dim_name) is set to a struct with empty cell fields .labels and
% .values, whose size has a 1 at position dim and 0 elsewhere.
    attr_present=cosmo_isfield(ds,attr_name);
    if ~attr_present
        ds.(attr_name)=struct();
    end

    xdim_present=cosmo_isfield(ds,['a.' dim_name]);
    if xdim_present
        % nothing to add
        return;
    end

    sz=zeros(1,2);
    sz(dim)=1;

    xdim=struct();
    xdim.labels=cell(sz);
    xdim.values=cell(sz);
    ds.a.(dim_name)=xdim;

function ys=insert_elements(xs,i,y,dim)
% Insert the elements of y in xs at position i
%
% Inputs:
%   xs      vector with n elements
%   i       integer insertion position in the range -n..n+1:
%             1..n+1   insert before the i-th element (n+1 appends)
%             0        append at the end
%             -n..-1   count from the end; -1 inserts before the last
%                      element, -n inserts at the start
%   y       elements to insert
%   dim     if 1 the result is a row vector, otherwise a column vector
%
% Output:
%   ys      xs with the elements of y inserted
    n=numel(xs);

    % a fractional index could silently drop elements when slicing below
    if ~isscalar(i) || i~=round(i)
        error('position index must be an integer scalar');
    end

    if i<-n || i>(n+1)
        error('position index %d must be in range -%d..%d',...
                        i,n,n+1);
    end

    if i<=0
        % map zero and negative positions onto 1..n+1
        i=n+i+1;
    end

    xs_col=xs(:);
    ys=[xs_col(1:(i-1));y(:);xs_col(i:end)];
    if dim==1
        ys=ys.';
    end

function dim_values=get_dim_values(labels,values,dim,opt)
% Return dimension values as a cell, with each vector oriented for dim
%
% Inputs:
%   labels      cell with dimension labels
%   values      cell with dimension values, one for each label
%   dim         1 or 2; the output cell is 1xn if dim==1 and nx1 if dim==2
%   opt         struct with field .matrix_labels; values whose label
%               matches matrix_labels are stored as-is
%
% Output:
%   dim_values  cell with the elements of values. Each value whose label
%               is not in matrix_labels must be a vector; it is transposed
%               if its size along dim is 1, so that vectors end up as
%               columns for dim==1 and as rows for dim==2.
    if ~iscell(labels)
        error('labels argument must be a cell');
    end

    if ~iscell(values)
        error('values argument must be a cell');
    end

    n=numel(labels);

    if numel(values)~=n
        error('size mismatch between labels and values');
    end

    matrix_labels=opt.matrix_labels;

    dim_values_shape=[1 1];
    dim_values_shape(3-dim)=n;

    dim_values=cell(dim_values_shape);
    for j=1:n
        label=labels{j};
        value=values{j};

        if ~cosmo_match({label},matrix_labels)
            % isvector also rejects N-d arrays that happen to have a
            % singleton dimension; these cannot be transposed
            if ~isvector(value)
                error(['dim value for %s must be a vector, because it '...
                            'is not set in the matrix_labels option'],...
                            label);
            end

            needs_transpose=size(value,dim)==1;

            if needs_transpose
                % non-conjugate transpose, so that complex values are
                % left unchanged
                value=value.';
            end
        end
        dim_values{j}=value;
    end





function values=get_attr_values(labels,attr)
% Get the elements for .sa or .fa that correspond to labels
%
% Inputs:
%   labels      cell with attribute names
%   attr        either a cell with one value for each label, or a struct
%               that has a field for each label
%
% Output:
%   values      cell with one value for each label. If attr is a struct
%               this has the same size as labels; if attr is a cell it is
%               attr itself.
%
% An error is raised if attr is neither a struct nor a cell, if attr is a
% struct that lacks a field named in labels, or if the number of values
% differs from the number of labels.
    if isstruct(attr)
        values=cell(size(labels));
        for k=1:numel(labels)
            label=labels{k};
            if ~isfield(attr,label)
                error('missing field %s', label);
            end
            values{k}=attr.(label);
        end
    elseif iscell(attr)
        values=attr;
    else
        error('illegal attr value: must be struct or cell');
    end

    n=numel(labels);

    if numel(values)~=n
        % one format specifier per argument; with fewer specifiers the
        % format would be recycled, resulting in a garbled message
        error(['number of values (%d) does not match the number of '...
                    'labels (%d)'],numel(values),n);
    end