cosmo randomize targets skl

function [randomized_targets, permutation] = cosmo_randomize_targets(ds, varargin)
    % provides randomized target labels
    %
    % [randomized_targets,permutation]=cosmo_randomize_targets(ds[,'seed',seed])
    %
    % Inputs:
    %   ds                    dataset struct with fields .sa.targets and
    %                         .sa.chunks, each with one value per sample
    %                         (P samples in total)
    %   'seed', seed          (optional) if provided, use this seed value for
    %                         pseudo-random number generation
    %
    %
    % Outputs:
    %   randomized_targets    P x 1 with randomized targets
    %                         If ds defines a repeated-measures design (which
    %                         requires that each chunk has the same set of
    %                         unique targets), then targets are randomized
    %                         separately for each chunk.
    %                         Otherwise (when each chunk is associated with
    %                         exactly one sample, i.e. all samples are
    %                         independent), the targets are randomized
    %                         without considering the chunk values.
    %   permutation           P x 1 with indices of permutation. It holds that
    %                         randomized_targets == ds.sa.targets(permutation).
    %
    % Notes:
    %   - an error is raised if ds has no .sa.chunks or .sa.targets, or if
    %     the chunks suggest a repeated-measures design but some chunk does
    %     not contain every unique target.
    %
    % Example:
    %     % generate tiny dataset with 15 chunks, each with two targets
    %     ds=cosmo_synthetic_dataset('nchunks',15);
    %     % show number of samples with targets 1 or 2
    %     histc(ds.sa.targets',1:2)
    %     %|| [15 15]
    %     % generate randomized targets
    %     rand_targets=cosmo_randomize_targets(ds);
    %     % the number of samples with targets 1 or 2 is the same ...
    %     histc(rand_targets',1:2)
    %     %|| [15 15]
    %     % ... but the targets are re-ordered
    %     all(ds.sa.targets==rand_targets)
    %     %|| false
    %     %
    %     % when using the 'seed' option, the output is deterministic
    %     % (multiple calls to this function always give the same output)
    %     rand_targets_deterministic=cosmo_randomize_targets(ds,'seed',314);
    %     rand_targets_deterministic'
    %     %|| [ 2 1 1 2 2 1 2 1 2 1 2 1 1 2 2 1 1 2 2 1 1 2 2 1 2 1 2 1 2 1 ]
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    opt = cosmo_structjoin(varargin);

    [targets, chunks] = get_targets_and_chunks(ds);
    nsamples = numel(targets);

    % target_idxs(j) is the position of targets(j) in unq_targets
    [nunq_targets, unq_targets, target_idxs] = get_unique(targets);
    [nunq_chunks, unq_chunks] = get_unique(chunks);

    if nunq_chunks == nsamples
        % between-subject design: every chunk holds exactly one sample, so
        % all samples are permuted jointly without regard to chunks
        rps = randperms_with_size(nsamples, opt);

        % randperms_with_size yields row vectors; force a column so that
        % permutation is P x 1 as documented, matching the within-subject
        % branch below
        permutation = rps{1}(:);
    else
        % within-subject design
        chunk_partition = cosmo_index_unique(chunks);

        % count number of samples in each chunk, and ensure that no target
        % is missing
        nchunks = numel(chunk_partition);
        samples_per_chunk = zeros(1, nchunks);
        for k = 1:nchunks
            sample_idxs = chunk_partition{k};
            h = histc(target_idxs(sample_idxs), 1:nunq_targets);
            if any(h == 0)
                missing_pos = find(h == 0, 1);
                % NOTE(review): reporting unq_chunks(k) assumes that
                % cosmo_index_unique orders its groups like unique(chunks)
                error(['.sa.chunks and .sa.targets suggest a repeated '...
                       'measure design, but chunk %d has missing '...
                       'target %d'], ...
                      unq_chunks(k), unq_targets(missing_pos));
            end

            samples_per_chunk(k) = numel(sample_idxs);
        end

        % do permutation in each chunk separately; indices are only
        % shuffled among the samples of the same chunk
        rps = randperms_with_size(samples_per_chunk, opt);
        permutation = zeros(nsamples, 1);
        for k = 1:nchunks
            rp = rps{k};
            sample_idxs = chunk_partition{k};
            permutation(sample_idxs) = sample_idxs(rp);
        end
    end

    randomized_targets = targets(permutation);

function rps = randperms_with_size(sizes, opt)
    % helper function: one random permutation for each requested size
    %
    % Inputs:
    %   sizes    1xN vector; sizes(k) is the length of the k-th permutation
    %   opt      options struct; if it has a field .seed, that value is
    %            forwarded to cosmo_rand as the 'seed' option
    % Output:
    %   rps      1xN cell; rps{k} is a random permutation of 1:sizes(k),
    %            obtained as the sort order of a slice of random values
    %
    % All random values are drawn in a single call to cosmo_rand, because
    % that call is computationally expensive.

    rand_args = {1, sum(sizes)};
    if isfield(opt, 'seed')
        rand_args = [rand_args, {'seed', opt.seed}];
    end
    all_rand_vals = cosmo_rand(rand_args{:});

    % offsets(idx)+1 .. offsets(idx+1) is the slice of random values
    % reserved for the idx-th permutation
    offsets = [0, cumsum(sizes(:)')];

    nperms = numel(sizes);
    rps = cell(1, nperms);
    for idx = 1:nperms
        segment = all_rand_vals((offsets(idx) + 1):offsets(idx + 1));

        % the sort order of random values is a random permutation
        [dummy, order] = sort(segment);
        rps{idx} = order;
    end

function [n, unq, idxs] = get_unique(xs)
    % helper function: unique values of xs and how to map onto them
    %
    % Outputs:
    %   n       number of unique values
    %   unq     unique values of xs, as returned by unique
    %   idxs    for each element of xs, its position in unq
    [unique_vals, dummy, positions] = unique(xs);
    unq = unique_vals;
    idxs = positions;
    n = numel(unique_vals);

function [targets, chunks] = get_targets_and_chunks(ds)
    % helper function: return ds.sa.targets and ds.sa.chunks
    %
    % Raises an error if ds lacks an .sa field, or if .sa lacks either the
    % .chunks or the .targets field.
    has_sa = isfield(ds, 'sa');
    if ~(has_sa && isfield(ds.sa, 'chunks') && isfield(ds.sa, 'targets'))
        error('dataset must have .sa.chunks and .sa.targets');
    end

    sa = ds.sa;
    targets = sa.targets;
    chunks = sa.chunks;