cosmo balance dataset skl

function [balanced_ds, idxs, classes] = cosmo_balance_dataset(ds, varargin)
    % select a subset of samples so that every target occurs equally often
    %
    % [balanced_ds,idxs,classes]=cosmo_balance_dataset(ds, ...)
    %
    % Inputs:
    %   ds                      dataset struct with fields .samples,
    %                           .sa.targets and .sa.chunks. All values in
    %                           .sa.chunks must be different from each other.
    %   'sample_balancer', f    (optional)
    %                           function handle with signature
    %                               [idxs,classes]=f(targets,seed)
    %                           where idxs is an SxC matrix with sample
    %                           indices for C classes and S samples per
    %                           class. If omitted (or empty), a builtin
    %                           balancer is used.
    %   'seed', s               (optional, default=1)
    %                           Seed passed to the sample balancer for
    %                           pseudo-random number generation.
    %
    % Output:
    %   balanced_ds             dataset with a subset of the samples from ds
    %                           so that each target occurs equally often.
    %                           Selection is (by default) done in a
    %                           pseudo-deterministic manner, i.e. it depends
    %                           on the seed.
    %   idxs                    SxC matrix as returned by the sample
    %                           balancer; balanced_ds is obtained by slicing
    %                           ds with idxs(:).
    %   classes                 Cx1 vector containing unique class labels
    %
    % Notes:
    %   - this function is to be used with MEEG datasets. it is not intended
    %     for fMRI data.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % option defaults; values supplied by the caller take precedence
    default_options = struct('seed', 1, 'sample_balancer', []);
    options = cosmo_structjoin(default_options, varargin{:});

    % raises an error if the dataset is not suitable for balancing
    check_inputs(ds, options);

    % pick the sample indices, then keep only those samples
    [idxs, classes] = select_subset(ds.sa.targets, options);
    keep_idxs = idxs(:);
    balanced_ds = cosmo_slice(ds, keep_idxs, 1);

function [idxs, classes] = select_subset(targets, opt)
    % compute balanced sample indices using either the balancer supplied
    % in opt.sample_balancer or, when that is empty, the builtin balancer.
    % The balancer is called with the targets and opt.seed.
    if isempty(opt.sample_balancer)
        balancer = @default_sample_balancer;
    else
        balancer = opt.sample_balancer;
    end

    [idxs, classes] = balancer(targets, opt.seed);

function [idxs, classes] = default_sample_balancer(targets, seed)
    % builtin balancer: for each class, pick as many sample indices as
    % the smallest class has, in an order determined by seeded
    % pseudo-random values.
    %
    % Output idxs has one column per class and one row per selected
    % sample; classes is the second output of cosmo_index_unique.
    [members, classes] = cosmo_index_unique(targets);
    counts = cellfun(@numel, members);
    n_classes = numel(counts);

    % every class is reduced to the size of the smallest class
    n_keep = min(counts);
    n_largest = max(counts);

    % a single PRNG call provides the random keys for all classes;
    % column k holds the keys for the k-th class
    random_keys = cosmo_rand(n_largest, n_classes, 'seed', seed);

    idxs = zeros(n_keep, n_classes);
    for col = 1:n_classes
        n_members = counts(col);

        % sorting the keys yields the positions 1:n_members in
        % random order
        [ignored, shuffled] = sort(random_keys(1:n_members, col));

        % the first n_keep shuffled positions select this class' samples
        chosen = shuffled(1:n_keep);
        idxs(:, col) = members{col}(chosen);
    end

function check_inputs(ds, opt)
    % verify that the dataset and options are suitable for balancing
    %
    % Inputs:
    %   ds      dataset struct; must pass cosmo_check_dataset and have
    %           fields .sa.targets and .sa.chunks, with all values in
    %           .sa.chunks different from each other
    %   opt     options struct; if it has a non-empty field
    %           .sample_balancer, that field must be a function handle
    %
    % An error is raised if any of these requirements is not met.
    cosmo_check_dataset(ds);

    % chunks and targets must be present
    raise_exception = true;
    cosmo_isfield(ds, {'sa.targets', 'sa.chunks'}, raise_exception);

    % a custom balancer is called as f(targets,seed), so anything other
    % than a function handle would fail later with an obscure error
    if isfield(opt, 'sample_balancer') && ...
            ~isempty(opt.sample_balancer) && ...
            ~isa(opt.sample_balancer, 'function_handle')
        error('The ''sample_balancer'' option must be a function handle');
    end

    % if any chunk value is repeated, there are fewer unique values
    % than elements
    chunks = ds.sa.chunks;
    if numel(unique(chunks)) ~= numel(chunks)
        error(['All values in .sa.chunks must be unique. If '...
               '*and only if* all '...
               'observations in .samples can be assumed to be '...
               'independent, for a dataset ds you can set '...
               '  ds.sa.chunks(:)=1:numel(ds.sa.chunks)  '...
               'to indicate independence. This assumption typically '...
               'only applies to M/EEG datasets; this function should '...
               'not be used for typical fMRI datasets']);
    end