function test_suite = test_chunkize
    % tests for cosmo_chunkize
    %
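    % A minimal usage sketch (illustrative only; mirrors how these tests
    % call cosmo_chunkize):
    %
    %   ds = cosmo_synthetic_dataset('nchunks', 6);  % 6 original chunks
    %   ds.sa.chunks = cosmo_chunkize(ds, 2);        % re-assign to 2 chunks
    %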
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_chunkize_basis
    ds = cosmo_synthetic_dataset('type', 'timelock', 'nreps', 8);
    ds.sa.chunks = reshape(repmat(1:8, 6, 1), [], 1);
    ds = cosmo_slice(ds, randperm(48));

    chunks = cosmo_chunkize(ds, 8);
    assertEqual(chunks, ds.sa.chunks);

    for j = 1:2:7
        chunks = cosmo_chunkize(ds, j);
        eq_chunks = bsxfun(@eq, chunks, chunks');
        eq_ds = bsxfun(@eq, ds.sa.chunks, ds.sa.chunks');

        m = eq_ds & ~eq_chunks;
        assert(~any(m(:)));
    end

    assertExceptionThrown(@()cosmo_chunkize(ds, 9), '');
    ds.sa = rmfield(ds.sa, 'chunks');
    assertExceptionThrown(@()cosmo_chunkize(ds, 2), '');

function test_chunkize_imbalance()
    ds = struct();
    ds.samples = (1:5)';
    assertExceptionThrown(@()cosmo_chunkize(ds, 2), '');
    ds.sa.chunks = 2 + [1 1 2 2 2]';
    assertExceptionThrown(@()cosmo_chunkize(ds, 2), '');
    ds.sa.targets = 10 + [1 2 1 2 2]';
    assertExceptionThrown(@()cosmo_chunkize(ds, 3), '');

    count = 2;
    res = cosmo_chunkize(ds, count);
    assert_chunkize_ok(ds, res, count);

    ds2 = cosmo_stack({ds, ds});
    res2 = cosmo_chunkize(ds2, count);
    assert_chunkize_ok(ds2, res2, count);

function test_all_unique_chunks_tiny()
    ds = struct();
    ds.samples = (1:5)';
    ds.sa.targets = 2 + [1 1 2 2 2]';
    ds.sa.chunks = 10 + [1 2 3 4 5]';

    for count = 2:5
        res = cosmo_chunkize(ds, count);
        assert_chunkize_ok(ds, res, count);
    end

    assertExceptionThrown(@()cosmo_chunkize(ds, 6), '');

function test_chunkize_very_unbalanced_chunks_big()
    % several input chunks, each containing multiple samples; after slightly
    % perturbing the targets, each output chunk should still get a similar
    % number of samples for each target
    ds = cosmo_synthetic_dataset('nreps', 6, 'ntargets', 5);

    nsamples = size(ds.samples, 1);
    ds.sa.chunks(:) = repmat(1:nsamples / 10, 1, 10);

    n_combis = max(ds.sa.chunks) * max(ds.sa.targets);

    targets = ds.sa.targets;
    n_swap = 5;

    while true
        rp = randperm(nsamples);
        ds.sa.targets = targets;
        ds.sa.targets(rp(1:n_swap)) = ds.sa.targets(rp(n_swap:-1:1));
        idxs = cosmo_index_unique({ds.sa.targets, ds.sa.chunks});
        n = cellfun(@numel, idxs);
        if min(n) >= 1 && max(n) <= 3 && std(n) < .1 && numel(n) == n_combis
            % not too unbalanced
            break
        end
    end

    nchunks = ceil(3 + rand() * 4);
    res = cosmo_chunkize(ds, nchunks);
    assert_chunkize_ok(ds, res, nchunks);

function test_chunkize_slight_unbalanced_chunks_big()
    % many small input chunks with a slight target imbalance; each output
    % chunk should still get a similar number of samples for each target
    ds = cosmo_synthetic_dataset('nreps', 6, 'ntargets', 5);

    nsamples = size(ds.samples, 1);
    ds.sa.chunks(:) = repmat(1:nsamples / 2, 1, 2);
    ds.sa.targets(1:5) = ds.sa.targets(2:6); % slight imbalance

    nchunks = ceil(rand() * 5);
    res = cosmo_chunkize(ds, nchunks);
    assert_chunkize_ok(ds, res, nchunks);

function test_chunkize_all_unique_independent_chunks()
    % each sample has its own unique chunk value
    ds = cosmo_synthetic_dataset('ntargets', 2, 'nchunks', 6 * 6);
    nsamples = size(ds.samples, 1);
    ds.sa.chunks(:) = ceil(rand() * 10) + (1:nsamples);
    ds.sa.targets = ds.sa.targets(randperm(nsamples));

    nchunks_candidates = [1 2 3 4 6 12 18];
    for nchunks = nchunks_candidates
        chunks = cosmo_chunkize(ds, nchunks);
        assert_chunkize_ok(ds, chunks, nchunks);

        idxs = cosmo_index_unique([ds.sa.targets chunks]);
        n = cellfun(@numel, idxs);

        % require full balance
        assert(all(n(1) == n(2:end)));
    end

function test_chunkize_dependent_balanced_chunks()
    % each combination of chunks and targets occurs equally often
    ntargets = ceil(2 + rand() * 4);
    nreps = ceil(2 + rand() * 4);
    nchunks = 36;
    ds = cosmo_synthetic_dataset('ntargets', ntargets, ...
                                 'nchunks', nchunks, 'nreps', nreps);
    nsamples = size(ds.samples, 1);
    ds = cosmo_slice(ds, randperm(nsamples));

    rep_idxs = cosmo_index_unique({ds.sa.chunks, ds.sa.targets});
    assert(all(cellfun(@numel, rep_idxs) == nreps));

    nchunks_candidates = [1 2 3 4 6 12 18];
    for nchunks = nchunks_candidates
        chunks = cosmo_chunkize(ds, nchunks);
        assert_chunkize_ok(ds, chunks, nchunks);

        idxs = cosmo_index_unique([ds.sa.targets chunks]);
        n = cellfun(@numel, idxs);

        % require full balance
        assert(all(n(1) == n(2:end)));
    end

function assert_chunkize_ok(src_ds, chunks, count)
    % number of items must match input dataset
    assertEqual(numel(src_ds.sa.chunks), numel(chunks));

    % must be balanced
    assert_chunks_targets_balanced(src_ds, chunks);

    % cannot have double dipping
    assert_no_double_dipping(src_ds, chunks);

    % must have the proper number of chunks
    assertEqual(numel(unique(chunks)), count);

    assert_chunks_targets_nonzero(src_ds, chunks);

function assert_chunks_targets_balanced(src_ds, chunks)
    idxs = cosmo_index_unique([src_ds.sa.targets chunks]);

    n = cellfun(@numel, idxs);

    % cannot test for 'optimal' balance due to combinatorial explosion;
    % this is a decent approach to make sure that chunks are not too
    % imbalanced
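    % for example (illustrative numbers): counts n = [2 2 3 1 2] have
    % std(n) of about 0.71 and max(n) - min(n) = 2, so both checks
    % below pass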
    assert(std(n) <= 1.5);
    assert(min(n) + 2 >= max(n));

function assert_chunks_targets_nonzero(src_ds, chunks)
    [unused, unused, t_idxs] = unique(src_ds.sa.targets);
    [unused, unused, c_idxs] = unique(chunks);

    nt = max(t_idxs);
    nc = max(c_idxs);
    h = zeros(nt, nc);

    ns = numel(chunks);
    for k = 1:ns
        t = t_idxs(k);
        c = c_idxs(k);
        h(t, c) = h(t, c) + 1;
    end

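    % every output chunk (column of h) must contain at least one sample,
    % and every target (row of h) must occur in at least one output chunk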
    assert(all(max(h, [], 1) > 0));
    assert(all(max(h, [], 2) > 0));

function assert_no_double_dipping(src_ds, chunks)
    % samples that were in the same chunk in src_ds must not end up in
    % different output chunks
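    % for example, source chunks [1 1 2 2] mapped to output chunks
    % [1 2 1 2] would split both source chunks across output chunks
    % and must be rejected here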
    [unq_src, unused, src_ids] = unique(src_ds.sa.chunks);
    [unq_trg, unused, trg_ids] = unique(chunks);

    n_src = numel(unq_src);
    n_trg = numel(unq_trg);
    n_samples = numel(src_ds.sa.chunks);
    chunk_count = zeros(n_src, n_trg);

    for k = 1:n_samples
        i = src_ids(k);
        j = trg_ids(k);
        chunk_count(i, j) = chunk_count(i, j) + 1;
    end

    assert(all(sum(chunk_count > 0, 2) == 1));