test average samples

function test_suite = test_average_samples
    % tests for cosmo_average_samples
    %
    % Suite entry point: gathers handles to the local functions of this file
    % and hands over to initTestSuite, which (presumably, as in MOxUnit)
    % reads 'test_functions' and assigns 'test_suite' -- confirm against the
    % test framework in use.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % the variable name must stay 'test_functions'; the initTestSuite
        % script looks it up by that name
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_average_samples_
    % Basic behaviour of cosmo_average_samples: default and 'ratio'
    % averaging, rejection of invalid options, a dataset with a single
    % chunk, and that averages never combine samples of different targets.
    ds = cosmo_synthetic_dataset();

    % with no options, and with 'ratio' .5, the averaged samples must be a
    % permutation of the original samples
    for option_set = {{}, {'ratio', .5}}
        extra_args = option_set{1};
        averaged = cosmo_average_samples(ds, extra_args{:});
        assertElementsAlmostEqual(sort(averaged.samples), sort(ds.samples));
        assertElementsAlmostEqual(sort(averaged.samples(:, 3)), ...
                                  sort(ds.samples(:, 3)));
    end

    % invalid option values / combinations
    must_fail = @(varargin)assertExceptionThrown(@() ...
                                 cosmo_average_samples(varargin{:}), '');
    must_fail(ds, 'ratio', .1);
    must_fail(ds, 'ratio', 3);
    must_fail(ds, 'ratio', .5, 'count', 2);

    % all samples in one chunk must still yield a valid dataset
    ds.sa.chunks(:) = 1;
    averaged = cosmo_average_samples(ds, 'ratio', .5);
    cosmo_check_dataset(averaged);

    % keep a single feature whose value is target * 1000 + sample position,
    % so the target and the contributing positions can be read back
    ds = cosmo_slice(ds, 3, 2);
    n_rows = size(ds.samples, 1);
    ds.samples = ds.sa.targets * 1000 + (1:n_rows)';

    for option_set = {{'ratio', .5, 'nrep', 10}, {'count', 3, 'nrep', 10}}
        extra_args = option_set{1};
        averaged = cosmo_average_samples(ds, extra_args{:});

        % no mixing of different targets: the fractional part only comes
        % from the position offsets
        offset = averaged.samples / 1000 - averaged.sa.targets;
        assertTrue(all(.00099 <= offset & offset < .05));
        assertElementsAlmostEqual(offset * 3000, round(offset * 3000));
    end

function test_average_samples_split_by
    % For every subset of the sample attributes {targets, chunks, subject,
    % modality}, averaging with that subset as 'split_by' must give one
    % average per unique combination of those attributes. Each average must
    % equal the mean of the matching samples and carry the attribute values
    % over. Splitting by {targets, chunks} must match the default call.
    %
    % Column 1 holds the synthetic-dataset size option suffix ('n' prefix
    % added below), column 2 the corresponding .sa field name.
    plural_singular = {'targets', 'targets'; ...
                       'chunks', 'chunks'; ...
                       'subjects', 'subject'; ...
                       'modalities', 'modality' ...
                      };
    n_dim = size(plural_singular, 1);

    % every true/false pattern over the attributes (whether to split by it)
    combis = cosmo_cartprod(repmat({{true, false}}, n_dim, 1)');
    for k = 1:size(combis, 1)
        combi = cell2mat(combis(k, :));

        % .sa field names to split by, always taken from column 2
        split_by = plural_singular(combi, 2);

        opt = struct();
        opt.seed = 0; % truly random data
        for j = 1:n_dim
            count = ceil(rand() * 2 + 1);
            opt.(['n' plural_singular{j, 1}]) = count;
        end

        ds = cosmo_synthetic_dataset(opt);

        % expected averages, computed independently
        values = cellfun(@(fn) ds.sa.(fn), split_by, 'UniformOutput', false);
        if any(combi)
            [idx, unq_cell] = cosmo_index_unique(values);
        else
            idx = {1:(size(ds.samples, 1))};
        end
        n_avg = numel(idx);
        n_features = size(ds.samples, 2);
        expected_samples = zeros(n_avg, n_features);
        for m = 1:n_avg
            expected_samples(m, :) = mean(ds.samples(idx{m}, :), 1);
        end

        result = cosmo_average_samples(ds, 'split_by', split_by);

        assertEqual(size(result.samples), size(expected_samples));

        % the output order is not specified; match result rows to expected
        % rows through the (random, hence distinct) first feature value
        delta = bsxfun(@minus, result.samples(:, 1), expected_samples(:, 1)');
        mapping = zeros(1, n_avg);
        for m = 1:n_avg
            [mn, mn_idx] = min(abs(delta(m, :)));
            assert(mn < 1e-5); % deal with rounding
            mapping(mn_idx) = m;
        end
        % a one-to-one matching is required
        assertEqual(sort(mapping), 1:n_avg);

        result_perm = cosmo_slice(result, mapping);
        assertElementsAlmostEqual(result_perm.samples, expected_samples);

        % attribute values of each average must be carried over
        for pos = 1:numel(split_by)
            assertEqual(unq_cell{pos}, result_perm.sa.(split_by{pos}));
        end

        % check default result; compare the actual split_by field names
        % rather than relying on linear indexing into the 2-D cell array
        if isequal(split_by', {'targets', 'chunks'})
            default_result = cosmo_average_samples(ds);
            assertEqual(result, default_result);
        end

    end

function test_average_samples_split_by_empty()
    % An empty 'split_by' collapses the whole dataset into a single average:
    % the mean over all samples.
    n_targets = ceil(rand() * 5 + 2);
    n_chunks = ceil(rand() * 5 + 2);
    ds = cosmo_synthetic_dataset('ntargets', n_targets, 'nchunks', n_chunks);

    averaged = cosmo_average_samples(ds, 'split_by', {});
    grand_mean = mean(ds.samples, 1);
    assertElementsAlmostEqual(averaged.samples, grand_mean);

function test_average_samples_exceptions
    % Every argument list below is invalid and must make
    % cosmo_average_samples raise an exception.
    ds = cosmo_synthetic_dataset('nreps', 5);

    % struct with samples only, lacking the other dataset fields
    samples_only = struct();
    samples_only.samples = randn(4);

    % dataset lacking the .sa.targets attribute
    ds_no_targets = ds;
    ds_no_targets.sa = rmfield(ds_no_targets.sa, 'targets');

    not_a_dataset = {{[]}, {samples_only}};
    illegal_count = {{ds, 'count', 6}, {ds, 'count', [2 2]}, ...
                     {ds, 'count', 3.5}, {ds, 'count', 0}};
    illegal_ratio = {{ds, 'ratio', 1.2}, {ds, 'ratio', -0.2}, ...
                     {ds, 'ratio', [.5 .5]}};
    mutually_exclusive = {{ds, 'ratio', .5, 'count', 2}};
    illegal_repeats = {{ds, 'repeats', [2 2]}, {ds, 'repeats', -1}, ...
                       {ds, 'resamplings', [2 2]}, {ds, 'resamplings', -1}, ...
                       {ds, 'resamplings', 1, 'repeats', 1}};
    missing_field = {{ds_no_targets}};
    illegal_split_by = {{ds, 'split_by', []}, {ds, 'split_by', struct()}, ...
                        {ds, 'split_by', 'foo'}, {ds, 'split_by', {1, 2}}};

    bad_arg_lists = [not_a_dataset, illegal_count, illegal_ratio, ...
                     mutually_exclusive, illegal_repeats, missing_field, ...
                     illegal_split_by];

    for i_case = 1:numel(bad_arg_lists)
        bad_args = bad_arg_lists{i_case};
        assertExceptionThrown(@() cosmo_average_samples(bad_args{:}), '');
    end

function test_average_samples_with_repeats
    % Check which samples cosmo_average_samples combines when target-chunk
    % combinations have unequal numbers of samples. This is done for several
    % 'count', 'ratio', 'repeats' and 'resamplings' settings.
    %
    % Every sample value encodes (feature, chunk, target, repeat) in separate
    % one-hot bit fields (see binarize_ds). After averaging,
    % check_with_helper can therefore recover which samples contributed to
    % each average.
    nchunks = ceil(rand() * 4 + 3);
    ntargets = ceil(rand() * 4 + 3);
    ncombi_max = ceil(rand() * 3 + 4);

    % sets the number of extra bits reserved in the repeats field (below)
    max_cyc = 5;

    ncombi_min = ceil(ncombi_max / 2);

    ds = cosmo_synthetic_dataset('nchunks', nchunks, ...
                                 'ntargets', ntargets, ...
                                 'nreps', ncombi_max);
    ds.sa = rmfield(ds.sa, 'rep');
    sp = cosmo_split(ds, {'targets', 'chunks'});
    n_splits = numel(sp);

    % Keep the first nkeep samples of each target-chunk combination and
    % number them 1..nkeep in .sa.repeats. nkeep ranges from ncombi_min to
    % ncombi_max - 1. combi_count(c, t) records nkeep for chunk c, target t.
    combi_count = zeros(nchunks, ntargets);

    for k = 1:n_splits
        if k == 1
            % ensure at least one combination has exactly ncombi_min
            nkeep = ncombi_min;
        else
            nkeep = ncombi_min + floor(rand() * (ncombi_max - ncombi_min));
        end

        ds_k = cosmo_slice(sp{k}, 1:nkeep);
        ds_k.sa.repeats = (1:nkeep)';

        combi_count(ds_k.sa.chunks(1), ds_k.sa.targets(1)) = nkeep;

        sp{k} = ds_k;
    end

    % every combination must still contain samples
    assert(all(cellfun(@(x)size(x.samples, 1), sp)));
    ds = cosmo_stack(sp);

    nfeatures = size(ds.samples, 2);

    % bit widths for features, chunks, targets, and repeats; the repeats
    % field has one bit per possible repeat plus ceil(log2(max_cyc + 1))
    % additional bits
    bws = [nfeatures, nchunks, ntargets, ceil(log2(max_cyc + 1)) + ncombi_max];

    % encode features, chunks, targets and repeats into single number
    dsb = binarize_ds(ds, bws);

    % helper: average dsb with the given options and verify the selection
    check_with = @(args, ...
                   count, ...
                   repeats) check_with_helper(dsb, args, count, repeats, ...
                                              nchunks, ntargets, ...
                                              ncombi_max, combi_count, ...
                                              bws);

    % explicit number of repeats, with 'count' or 'ratio'
    for repeats = [1, ceil(rand() * ncombi_max)]
        for count = [1, ceil(rand() * ncombi_min)]
            check_with({'count', count, 'repeats', repeats}, ...
                       count, repeats);
        end
        for ratio = [.5, .3 + rand() * .7]
            count = round(ratio * min(combi_count(:)));
            check_with({'ratio', ratio, 'repeats', repeats}, ...
                       count, repeats);
        end
    end

    % number of repeats derived from the default or from 'resamplings'
    for resamplings = [0, 1, 2 + round(rand() * 4)]
        count = ceil(rand() * ncombi_min);
        if resamplings == 0
            repeats = floor(ncombi_min / count);
            args = {'count', count};
        else
            repeats = floor(resamplings * ncombi_min / count);
            args = {'count', count, 'resamplings', resamplings};
        end

        check_with(args, count, repeats);
    end

function check_with_helper(dsb, args, count, repeats, ...
                           nchunks, ntargets, ncombi_max, combi_count, bws)
    % Average the binarized dataset dsb using options args, then verify:
    %   - all features of an average were built from the same samples;
    %   - each average combines exactly 'count' distinct samples;
    %   - within each chunk-target combination, the available samples
    %     (1..combi_count) are used equally often (up to a difference of 1),
    %     and no sample beyond that range is used;
    %   - each combination contributes count * repeats samples overall.
    averaged = cosmo_average_samples(dsb, args{:});
    [chunks, targets, ids] = unbinarize_ds(averaged, bws, count);

    n_averages = size(ids, 1);
    n_feat = size(dsb.samples, 2);

    % usage(c, t, r): how often repeat r of chunk c, target t was averaged
    usage = zeros(nchunks, ntargets, ncombi_max);

    for row = 1:n_averages
        used_ids = ids{row, 1};
        % the same samples must be selected for every feature
        for col = 2:n_feat
            assertEqual(used_ids, ids{row, col});
        end

        % no sample may appear twice in one average
        assert(all(diff(sort(used_ids(:))) > 0));

        % number of averaged samples must match
        assertEqual(numel(used_ids), count);

        c = chunks(row);
        t = targets(row);
        usage(c, t, used_ids) = usage(c, t, used_ids) + 1;
    end

    % each available sample is selected about equally often
    [n_c, n_t] = size(combi_count);
    for ci = 1:n_c
        for ti = 1:n_t
            per_repeat = squeeze(usage(ci, ti, :));
            n_avail = combi_count(ci, ti);

            available = per_repeat(1:n_avail);
            assert(max(available) - min(available) <= 1);

            beyond = per_repeat((n_avail + 1):end);
            assert(all(beyond == 0));
        end
    end

    % every chunk-target combination was used count * repeats times
    per_combination = sum(usage, 3);
    assert(isequal(per_combination, repmat(count * repeats, n_c, n_t)));

function [chunks, targets, ids] = unbinarize_ds(ds, bws, counts)
    % Recover, from an averaged binarized dataset, the chunk and target of
    % each sample and which repeats were averaged into each value.
    %
    % ds.samples values are means of 'counts' values built by binarize_ds
    % with bit widths bws; bws(end) is the width of the repeats field.
    % Outputs: chunks and targets are nsamples x 1; ids{k, j} lists the
    % repeat numbers that were averaged into ds.samples(k, j).
    % Features of one sample must agree on chunk and target, and each
    % feature must decode to its own column index; this is asserted.
    n_rep_bits = bws(end);
    rep_span = 2^n_rep_bits;

    [n_rows, n_cols] = size(ds.samples);
    ids = cell(n_rows, n_cols);
    chunks = zeros(n_rows, 1);
    targets = zeros(n_rows, 1);

    for row = 1:n_rows
        for col = 1:n_cols
            value = ds.samples(row, col);

            % Multiplying by the count undoes the averaging; the low field
            % then holds the sum of one-hot repeat codes, so each set bit
            % identifies one averaged repeat.
            rep_bits = quick_dec2bin(mod(value * counts, rep_span), n_rep_bits);
            ids{row, col} = n_rep_bits - find(rep_bits) + 1;

            % the high fields hold one-hot (feature, chunk, target) codes
            % of the form 2^(x - 1); log2 recovers x
            fields = decode(floor(value / rep_span), bws(1:(end - 1)));
            assertEqual(log2(fields(1)) + 1, col);
            chunk = log2(fields(2)) + 1;
            target = log2(fields(3)) + 1;

            if col > 1
                assertEqual(chunk, chunks(row));
                assertEqual(target, targets(row));
            else
                chunks(row) = chunk;
                targets(row) = target;
            end
        end
    end

function bds = binarize_ds(ds, bws)
    % Return a copy of ds whose sample value at (row, col) encodes
    % [col, chunks, targets, repeats] of that row via encode(..., bws).
    bds = ds;
    n_rows = size(ds.samples, 1);
    n_cols = size(ds.samples, 2);

    for row = 1:n_rows
        row_sa = cosmo_slice(ds.sa, row, 1, 'struct');
        row_meta = [row_sa.chunks row_sa.targets row_sa.repeats];
        for col = 1:n_cols
            bds.samples(row, col) = encode([col, row_meta], bws);
        end
    end

function p = encode(vs, bws)
    % Pack several small positive integers into one number, one-hot.
    %
    % Value vs(k) is stored in a field of bws(k) bits in which only bit
    % vs(k) is set, counting from 1 at the field's least significant end.
    % This is NOT the plain binary value of vs(k). Fields are concatenated
    % with vs(1) in the most significant position:
    %     p = sum_k 2^(vs(k) - 1) * 2^(bws(k+1) + ... + bws(n))
    % Each vs(k) must be an integer in 1..bws(k). Other values would not
    % fit in the field; previously a 0 silently widened the field by one
    % bit and collided with other codes.
    %
    % Because of the one-hot form, decoding a field gives 2^(vs(k) - 1);
    % vs(k) is recovered as log2 of that plus 1.
    n = numel(bws);
    assert(numel(vs) == n);

    vs = vs(:)';
    bws = bws(:)';
    assert(all(vs == round(vs) & vs >= 1 & vs <= bws), ...
           'each value vs(k) must be an integer in 1..bws(k)');

    % number of less significant bits below each field
    bits_below = sum(bws) - cumsum(bws);

    % distinct powers of two: the sum is exact while p < 2^53
    p = sum(2 .^ (vs - 1 + bits_below));

function vs = decode(p, bws)
    % Split a packed non-negative integer into its bit fields.
    %
    % p is read as an unsigned integer of sum(bws) bits. It is cut into
    % consecutive fields of bws(1), bws(2), ... bits, with the first field
    % in the most significant position. vs(k) is the numeric value of the
    % k-th field:
    %     decode(P) = [bin2dec(PB1) ... bin2dec(PBn)]
    % with PBk the k-th field of the binary representation of P.
    %
    % This undoes encode's field packing. However, encode stores each value
    % v one-hot, so the decoded field is 2^(v - 1), not v itself.
    assert(isscalar(p) && round(p) == p, 'p must be an integer scalar');
    assert(p >= 0 && p < 2^sum(bws), ...
           'p must be non-negative and fit in sum(bws) bits');

    n = numel(bws);
    vs = zeros(1, n);

    % number of less significant bits below each field
    bits_below = sum(bws) - cumsum(bws);
    for k = 1:n
        % division by a power of two and floor/mod are exact for integers
        vs(k) = mod(floor(p / 2^bits_below(k)), 2^bws(k));
    end

function arr = quick_dec2bin(x, bw)
    % Convert a non-negative integer to a fixed-width numeric bit array.
    %
    % Input:
    %   x       non-negative integer scalar with x < 2^bw
    %   bw      number of output bits
    % Output:
    %   arr     1 x bw array of 0 and 1 values, most significant bit first
    %
    % Input that does not fit raises a clear error. This covers negative
    % numbers, which recent dec2bin versions would turn into a two's
    % complement string, and numbers needing more than bw bits, which
    % previously failed with an index error.
    assert(isscalar(x) && round(x) == x, 'x must be an integer scalar');
    assert(x >= 0 && x < 2^bw, 'x must be non-negative and fit in bw bits');

    arr = zeros(1, bw);

    xbs = dec2bin(x);
    % right-align the digits; higher bits stay zero
    arr((bw - numel(xbs) + 1):end) = (xbs == '1');

function x = quick_bin2dec(arr)
    % Convert an array of binary digits (most significant first) to the
    % corresponding non-negative number.
    %
    % arr may be a row or a column vector of 0/1 values, numeric or
    % logical. An empty array gives 0. The inner product avoids the
    % row-times-column broadcasting that '.*' would apply to a column
    % input.
    bits = double(arr(:));
    weights = 2 .^ ((numel(bits) - 1):-1:0);
    x = weights * bits;