cosmo slice skl

function ds = cosmo_slice(ds, to_select, dim, type_or_check)
    % Slice a dataset by samples (the default) or features
    %
    % sliced_ds=cosmo_slice(ds, elements_to_select[, dim][check|'struct'])
    %
    % Inputs:
    %   ds                    One of:
    %                         - dataset struct to be sliced, with PxQ field
    %                           .samples and optionally fields .fa, .sa and .a.
    %                         - PxQ cell
    %                         - PxQ logical or numeric array
    %   elements_to_select    either a binary mask or a list of indices of
    %                         the samples (if dim==1) or features (if dim==2)
    %                         to select. If a binary mask then the number of
    %                         elements should match the size of ds in the
    %                         dim-th dimension.
    %   dim                   Slicing dimension: along samples (dim==1) or
    %                         features (dim==2). (default: 1).
    %   check                 Boolean that indicates that if ds is a dataset,
    %                         whether it should be checked for proper
    %                         structure. (default: true).
    %   'struct'              If provided and ds is a struct, then
    %                         all fields of ds, which are assumed to be cell
    %                         or arrays,  are sliced.
    %
    % Output:
    %   sliced_ds             - If ds is a cell or array then sliced_ds is
    %                           the result of slicing ds along the dim-th
    %                           dimension. The result is of size NxQ (if
    %                           dim==1) or PxN (if dim==2), where N is the
    %                           number of non-zero values in
    %                           elements_to_select.
    %                         - If ds is a dataset struct then
    %                           sliced_ds.samples is the result of slicing
    %                           ds.samples.
    %                           If present, fields .sa (if dim==1) or
    %                           .fa (dim==2) are sliced as well.
    %                         - when ds is a struct and the 'struct' option was
    %                           given, then all fields in ds are sliced.
    %
    % Examples:
    %     % make a simple dataset
    %     ds=struct();
    %     ds.samples=reshape(1:12,4,3); % 4 samples, 3 features
    %     % sample attributes
    %     ds.sa.chunks=[1 1 2 2]';
    %     ds.sa.targets=[1 2 1 2]';
    %     % feature attributes
    %     ds.fa.i=[3 8 13];
    %     ds.fa.roi={'vt','loc','v1'};
    %     % dataset attributes
    %     ds.a.note='an example';
    %     % display dataset
    %     cosmo_disp(ds);
    %     %|| .samples
    %     %||   [ 1         5         9
    %     %||     2         6        10
    %     %||     3         7        11
    %     %||     4         8        12 ]
    %     %|| .sa
    %     %||   .chunks
    %     %||     [ 1
    %     %||       1
    %     %||       2
    %     %||       2 ]
    %     %||   .targets
    %     %||     [ 1
    %     %||       2
    %     %||       1
    %     %||       2 ]
    %     %|| .fa
    %     %||   .i
    %     %||     [ 3         8        13 ]
    %     %||   .roi
    %     %||     { 'vt'  'loc'  'v1' }
    %     %|| .a
    %     %||   .note
    %     %||     'an example'
    %     %
    %     % (snippet) select samples (row) in a dataset
    %     % ds is a dataset struct
    %     sample_ids=[3 2];
    %     % select third and second sample (in that order)
    %     sliced_ds=cosmo_slice(ds,sample_ids,1);
    %     %
    %     cosmo_disp(sliced_ds);
    %     %|| .samples
    %     %||   [ 3         7        11
    %     %||     2         6        10 ]
    %     %|| .sa
    %     %||   .chunks
    %     %||     [ 2
    %     %||       1 ]
    %     %||   .targets
    %     %||     [ 1
    %     %||       2 ]
    %     %|| .fa
    %     %||   .i
    %     %||     [ 3         8        13 ]
    %     %||   .roi
    %     %||     { 'vt'  'loc'  'v1' }
    %     %|| .a
    %     %||   .note
    %     %||     'an example'
    %     %
    %     % select third and second feature (in that order)
    %     sliced_ds=cosmo_slice(ds, [3 2], 2);
    %     cosmo_disp(sliced_ds);
    %     %|| .samples
    %     %||   [  9         5
    %     %||     10         6
    %     %||     11         7
    %     %||     12         8 ]
    %     %|| .sa
    %     %||   .chunks
    %     %||     [ 1
    %     %||       1
    %     %||       2
    %     %||       2 ]
    %     %||   .targets
    %     %||     [ 1
    %     %||       2
    %     %||       1
    %     %||       2 ]
    %     %|| .fa
    %     %||   .i
    %     %||     [ 13         8 ]
    %     %||   .roi
    %     %||     { 'v1'  'loc' }
    %     %|| .a
    %     %||   .note
    %     %||     'an example'
    %     %
    %     % using a logical mask, select features with odd value for .i
    %     msk=mod(ds.fa.i,2)==1;
    %     disp(msk)
    %     %|| [1 0 1]
    %     sliced_ds=cosmo_slice(ds, msk, 2);
    %     cosmo_disp(sliced_ds);
    %     %|| .samples
    %     %||   [ 1         9
    %     %||     2        10
    %     %||     3        11
    %     %||     4        12 ]
    %     %|| .sa
    %     %||   .chunks
    %     %||     [ 1
    %     %||       1
    %     %||       2
    %     %||       2 ]
    %     %||   .targets
    %     %||     [ 1
    %     %||       2
    %     %||       1
    %     %||       2 ]
    %     %|| .fa
    %     %||   .i
    %     %||     [ 3        13 ]
    %     %||   .roi
    %     %||     { 'vt'  'v1' }
    %     %|| .a
    %     %||   .note
    %     %||     'an example'
    %
    %     % slice all fields in a struct
    %     s=struct();
    %     s.a_field=[1 2 3; 4 5 6];
    %     s.another_field={'this','is','fun'};
    %     cosmo_disp(s);
    %     %|| .a_field
    %     %||   [ 1         2         3
    %     %||     4         5         6 ]
    %     %|| .another_field
    %     %||   { 'this'  'is'  'fun' }
    %     %
    %     % select first, third, third, and second column (dim=2)
    %     t=cosmo_slice(s, [1 3 3 2], 2, 'struct');
    %     cosmo_disp(t);
    %     %|| .a_field
    %     %||   [ 1         3         3         2
    %     %||     4         6         6         5 ]
    %     %|| .another_field
    %     %||   { 'this'  'fun'  'fun'  'is' }
    %
    %
    % Notes:
    %   - do_check=false may be preferred for slice-intensive operations such
    %     as when used in searchlights
    %   - this function does not support arrays with more than two dimensions.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % deal with 2, 3, or 4 input arguments
    if nargin < 3 || isempty(dim)
        dim = 1;
    end
    if nargin < 4 || isempty(type_or_check)
        type_or_check = true;
    end

    if iscell(ds) || isnumeric(ds) || islogical(ds)
        ds = slice_array(ds, to_select, dim, type_or_check);
    elseif isstruct(ds)
        if strcmp(type_or_check, 'struct')
            ds = slice_struct(ds, to_select, dim, type_or_check);
        else
            if ~isfield(ds, 'samples')
                error(['Expected dataset struct. To slice ordinary '...
                       'structs use "struct" as last argument']);
            end

            if type_or_check
                % check kosherness
                cosmo_check_dataset(ds);
            end

            dim_size = size(ds.samples, dim);

            % slice the samples
            ds.samples = slice_array(ds.samples, to_select, dim, type_or_check);

            % now deal with either feature or sample attributes
            attr_fns = {'sa', 'fa'};
            attr_fn = attr_fns{dim}; % fieldname of attribute to slice

            if isfield(ds, attr_fn)
                ds.(attr_fn) = slice_struct(ds.(attr_fn), to_select, ...
                                            dim, type_or_check, dim_size);
            end
        end
    else
        error('Illegal input: expected cell, array or struct');
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % helper functions
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function y = slice_struct(x, to_select, dim, do_check, expected_size)
    if nargin < 5
        expected_size = NaN;
    end

    y = struct();
    fns = fieldnames(x);
    for k = 1:numel(fns)
        fn = fns{k};
        v = x.(fn);

        v_size = size(v, dim);

        % ensure all input sizes are the same
        if isnan(expected_size)
            expected_size = v_size;
        elseif v_size ~= expected_size
            error(['Size mismatch for %s: expected %d but found %d', ...
                   ' elements in dimension %d'], ...
                  fn, v_size, expected_size, dim);
        end

        y.(fn) = slice_array(v, to_select, dim, do_check);
    end

function y = slice_array(x, to_select, dim, do_check)
    if do_check
        check_size(x, to_select, dim);
        if ~isscalar(dim) || ~isnumeric(dim)
            error('dim must be 1 or 2');
        end
    end

    if dim == 1
        y = x(to_select, :);
    elseif dim == 2
        y = x(:, to_select);
    else
        error('dim must be 1 or 2');
    end

function check_size(x, to_select, dim)
    if islogical(to_select) && ...
                size(x, dim) ~= numel(to_select)
        % be a bit more strict than matlab - binary array must have
        % exactly the correct size
        error('Logical mask should have %d elements, found %d', ...
              size(x, dim), numel(to_select));
    end

    if numel(size(x)) ~= 2
        error('Only 2D arrays are allowed');
    end

    if sum(size(to_select) > 1) > 1
        error('elements to select should be in vector');
    end