cosmo split skl

function ds_splits = cosmo_split(ds, split_by, dim, check)
    % splits a dataset by unique values in (a) sample or feature attribute(s).
    %
    % ds_splits = cosmo_split(ds, split_by[, dim[, check]])
    %
    % Inputs:
    %   ds          dataset struct
    %   split_by    fieldname for the sample (if dim==1) or feature (if dim==2)
    %               attribute by which the dataset is split.
    %               It can also be a cell with fieldnames, in which case the
    %               dataset is split by the set of combinations from these
    %               fieldnames.
    %               If empty, the dataset is not split and the output is the
    %               single-element cell {ds}.
    %   dim         1 (split by samples; default) or 2 (split by features).
    %               An empty value selects the default; any other value
    %               (including a non-scalar one) raises an error.
    %   check       boolean indicating whether the input ds is checked as being
    %               a proper dataset (default: true)
    %
    % Returns:
    %   ds_splits   1xP cell, if there are P unique values for the (set of)
    %               attribute(s) indicated by split_by and dim.
    %
    % Examples:
    %     ds=cosmo_synthetic_dataset();
    %     %
    %     % split by targets
    %     splits=cosmo_split(ds,'targets');
    %     cosmo_disp(splits{2}.sa);
    %     %|| .targets
    %     %||   [ 2
    %     %||     2
    %     %||     2 ]
    %     %|| .chunks
    %     %||   [ 1
    %     %||     2
    %     %||     3 ]
    %     %
    %     % split by chunks
    %     splits=cosmo_split(ds,'chunks');
    %     cosmo_disp(splits{3}.sa);
    %     %|| .targets
    %     %||   [ 1
    %     %||     2 ]
    %     %|| .chunks
    %     %||   [ 3
    %     %||     3 ]
    %     %
    %     % split by chunks and targets
    %     splits=cosmo_split(ds,{'chunks','targets'});
    %     cosmo_disp(splits{5}.sa);
    %     %|| .targets
    %     %||   [ 1 ]
    %     %|| .chunks
    %     %||   [ 3 ]
    %
    %     % take an MEEG time-freq dataset, and split by time and frequency
    %     ds=cosmo_synthetic_dataset('type','timefreq','size','big');
    %     %
    %     % dataset has 306 channels, 7 frequencies and 5 time points
    %     cosmo_disp(ds.fa)
    %     %|| .chan
    %     %||   [ 1         2         3  ...  304       305       306 ]@1x10710
    %     %|| .freq
    %     %||   [ 1         1         1  ...  7         7         7 ]@1x10710
    %     %|| .time
    %     %||   [ 1         1         1  ...  5         5         5 ]@1x10710
    %     %
    %     % split by time and frequency. Since splitting is done on the feature
    %     % dimension, the third argument (with value 2) is mandatory
    %     splits=cosmo_split(ds,{'time','freq'},2);
    %     % there are 7 * 5 = 35 splits, each with 306 features
    %     numel(splits)
    %     %|| 35
    %     cosmo_disp(cellfun(@(x) size(x.samples,2),splits))
    %     %|| [ 306       306       306  ...  306       306       306 ]@1x35
    %     cosmo_disp(splits{18}.fa)
    %     %|| .chan
    %     %||   [ 1         2         3  ...  304       305       306 ]@1x306
    %     %|| .freq
    %     %||   [ 4         4         4  ...  4         4         4 ]@1x306
    %     %|| .time
    %     %||   [ 3         3         3  ...  3         3         3 ]@1x306
    %     %
    %     % using cosmo_stack brings the split elements together again
    %     humpty_dumpty=cosmo_stack(splits,2);
    %     cosmo_disp(humpty_dumpty.fa)
    %     %|| .chan
    %     %||   [ 1         2         3  ...  304       305       306 ]@1x10710
    %     %|| .freq
    %     %||   [ 1         1         1  ...  7         7         7 ]@1x10710
    %     %|| .time
    %     %||   [ 1         1         1  ...  5         5         5 ]@1x10710
    %
    % Note:
    %   - This function is like the inverse of cosmo_stack; if
    %
    %       >> ds_splits=cosmo_split(ds, split_by, dim),
    %
    %     produces output (i.e., does not throw an error), then using
    %
    %       >> ds_humpty_dumpty=cosmo_stack(ds_splits,dim)
    %
    %     means that ds and ds_humpty_dumpty contain the same data, except that
    %     the order of the data (in the rows [columns] of .samples, or
    %     .sa [.fa]) may be different if dim==1 [dim==2].
    %   - An error is raised if dim is not 1 or 2, if split_by is neither a
    %     string nor a cell of strings, or if an attribute named in split_by
    %     is absent or is not 1D.
    %
    %
    % See also: cosmo_stack, cosmo_slice
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % set default dim & check input
    if nargin < 4
        check = true;
    end

    if nargin < 3 || isempty(dim)
        dim = 1;
    elseif ~isscalar(dim) || (dim ~= 1 && dim ~= 2)
        % the isscalar guard comes first so that the short-circuit '&&'
        % only ever sees scalar operands; a non-scalar dim then yields the
        % intended error message instead of an unrelated operator error
        error('dim should be 1 or 2');
    end

    if check
        cosmo_check_dataset(ds);
    end

    % if empty split return just the dataset itself
    if isempty(split_by)
        ds_splits = {ds};
        return
    end

    % ensure that split_by is a cell of strings
    if ischar(split_by)
        split_by = {split_by};
    elseif ~iscellstr(split_by)
        error('split_by must be string or cell of strings');
    end

    % get values to split by (one cell element per attribute in split_by)
    split_values = get_attr_values(ds, split_by, dim);

    % get indices of unique parts
    split_idxs = cosmo_index_unique(split_values);

    % allocate space for output
    n = numel(split_idxs);
    ds_splits = cell(1, n);

    % slice for each unique part; the last argument (false) is passed
    % through to cosmo_slice unchanged
    for k = 1:n
        ds_splits{k} = cosmo_slice(ds, split_idxs{k}, dim, false);
    end

function values = get_attr_values(ds, split_by, dim)
    % Collects the attribute values named in split_by from ds.sa (dim==1)
    % or ds.fa (dim==2).
    %
    % Inputs:
    %   ds          struct with a field .sa or .fa (whichever dim selects)
    %   split_by    cell with attribute field names
    %   dim         1 (use .sa) or 2 (use .fa)
    %
    % Returns:
    %   values      Nx1 cell (N=numel(split_by)); values{j} is the content
    %               of the attribute named split_by{j}
    %
    % An error is raised if an attribute's value is not 1D; presence of the
    % fields is verified through cosmo_isfield (called with 'true' as third
    % argument).

    % select the attribute struct name based on the dimension
    candidate_names = {'sa', 'fa'};
    attr_name = candidate_names{dim};

    % verify that the attribute struct exists before accessing it
    cosmo_isfield(ds, attr_name, true);
    attr_struct = ds.(attr_name);

    n_keys = numel(split_by);
    values = cell(n_keys, 1);

    for j = 1:n_keys
        label = split_by{j};

        % verify that the requested attribute exists
        cosmo_isfield(attr_struct, label, true);
        candidate = attr_struct.(label);

        if is_1d(candidate)
            values{j} = candidate;
        else
            error('value for ''.%s.%s'' must be 1D', attr_name, label);
        end
    end

function tf = is_1d(x)
    % Returns true if x has at most one dimension with size exceeding 1
    % (so scalars and empty arrays are also considered 1D).
    long_dims = size(x) > 1;
    tf = nnz(long_dims) <= 1;