test crossvalidation measure

function test_suite = test_crossvalidation_measure
    % tests for cosmo_crossvalidation_measure
    %
    % Entry point of this test file: returns test_suite, which is
    % presumably populated by the initTestSuite script from the local
    % test_* functions below (TODO confirm against the test framework).
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % NOTE: the variable name 'test_functions' must not change; the
        % initTestSuite script reads it from this workspace.
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % initTestSuite is a script (not a function), so it runs in this
    % workspace and can see test_functions
    initTestSuite;

function test_crossvalidation_measure_regression
    % regression test: fixed synthetic dataset, fixed expected outputs for
    % several output types, classifiers, normalization and averaging
    ds = cosmo_synthetic_dataset('ntargets', 6, 'nchunks', 4);
    ds.sa.targets = ds.sa.targets + 10;
    ds.sa.chunks = ds.sa.chunks + 20;

    opt = struct();
    opt.partitions = cosmo_nfold_partitioner(ds);
    opt.classifier = @cosmo_classify_lda;

    % default output is accuracy
    res = cosmo_crossvalidation_measure(ds, opt);
    assertElementsAlmostEqual(res.samples, 0.6250);
    assertEqual(res.sa, cosmo_structjoin('labels', {'accuracy'}));

    opt.output = 'accuracy';
    res2 = cosmo_crossvalidation_measure(ds, opt);
    assertEqual(res, res2);

    opt.output = 'winner_predictions';
    res3 = cosmo_crossvalidation_measure(ds, opt);
    assertEqual(res3.samples, 10 + [1 2 3 4 5 5 4 6 2 4 2 6 ...
                                    1 2 3 4 6 6 1 3 3 4 3 1]');
    assertEqual(res3.sa, rmfield(ds.sa, 'chunks'));

    % use deprecated output options; 'predictions' is the deprecated name
    % for 'winner_predictions' and must give the same result. Warnings are
    % silenced here and the original warning state is restored on exit.
    warning_state = cosmo_warning();
    warning_state_resetter = onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');
    opt.output = 'predictions';
    res3a = cosmo_crossvalidation_measure(ds, opt);
    assertEqual(res3, res3a);

    opt.output = 'fold_accuracy';
    res4 = cosmo_crossvalidation_measure(ds, opt);
    assertElementsAlmostEqual(res4.samples, [5 2 5 3]' / 6);
    assertEqual(res4.sa.folds, (1:4)');

    % test different classifier
    opt.classifier = @cosmo_classify_nn;
    opt.partitions = cosmo_nfold_partitioner(ds);
    opt.output = 'winner_predictions';

    res6 = cosmo_crossvalidation_measure(ds, opt);
    assertEqual(res6.samples, 10 + [1 2 3 1 5 6 4 6 5 4 6 6 6 ...
                                    2 3 4 2 5 1 2 3 4 3 1]');
    % test normalization option
    opt.normalization = 'zscore';
    res7 = cosmo_crossvalidation_measure(ds, opt);
    assertEqual(res7.samples, 10 + [1 2 3 5 5 6 4 6 5 4 6 6 6 ...
                                    2 3 4 5 5 1 5 3 1 3 1]');

    % test with averaging samples; averaging a single sample must give the
    % same result as no averaging
    opt = rmfield(opt, 'normalization');
    opt.average_train_count = 1;
    res8 = cosmo_crossvalidation_measure(ds, opt);
    assertEqual(res8.samples, 10 + [1 2 3 1 5 6 4 6 5 4 6 6 6 ...
                                    2 3 4 2 5 1 2 3 4 3 1]');

    opt.average_train_count = 2;
    opt.average_train_resamplings = 5;
    res9 = cosmo_crossvalidation_measure(ds, opt);
    assertEqual(res9.samples, 10 + [1 2 3 4 5 6 4 6 2 4 6 6 1 ...
                                    2 3 4 5 6 1 2 3 4 5 1]');

function test_fold_accuracy()
    % each element of the 'fold_accuracy' output must equal the accuracy
    % obtained when running the measure on that single fold alone
    rand_count = @()ceil(rand() * 5) + 5;   % random integer in 6..10

    n_targets = rand_count();
    n_chunks = rand_count();
    n_reps = rand_count();
    ds = cosmo_synthetic_dataset('ntargets', n_targets, ...
                                 'nchunks', n_chunks, ...
                                 'nreps', n_reps);
    ds.samples(:) = randn(size(ds.samples));
    ds.sa.targets = ds.sa.targets + 10;
    ds.sa.chunks = ds.sa.chunks + 20;

    all_partitions = cosmo_nchoosek_partitioner(ds, 3);

    measure_opt = struct();
    measure_opt.partitions = all_partitions;
    measure_opt.classifier = @cosmo_classify_nn;
    measure_opt.output = 'fold_accuracy';

    n_folds = numel(measure_opt.partitions.train_indices);

    % one accuracy value (and one fold index) per fold
    result = cosmo_crossvalidation_measure(ds, measure_opt);
    assertEqual(size(result.samples), [n_folds, 1]);
    assertEqual(size(result.sa.folds), [n_folds, 1]);

    for i_fold = 1:n_folds
        % restrict the partitions to fold i_fold only
        single_opt = measure_opt;
        single_opt.partitions.train_indices = all_partitions.train_indices(i_fold);
        single_opt.partitions.test_indices = all_partitions.test_indices(i_fold);

        single_result = cosmo_crossvalidation_measure(ds, single_opt);
        assertElementsAlmostEqual(result.samples(i_fold), single_result.samples);
    end

function test_fold_predictions
    % 'fold_predictions' output must match a direct classifier call per fold
    % and cosmo_crossvalidate; 'winner_predictions' must select, for each
    % sample, a target that is predicted most often across folds
    randint = @()ceil(rand() * 4) + 1;   % random integer in 2..5

    ntargets = randint();
    targets_offset = randint();
    nchunks = randint() + 1;
    ds = cosmo_synthetic_dataset('ntargets', ntargets, ...
                                 'nchunks', nchunks, ...
                                 'nreps', randint());
    ds.samples(:) = randn(size(ds.samples));
    ds.sa.targets = ds.sa.targets + targets_offset;
    ds.sa.chunks = ds.sa.chunks + 20;

    opt = struct();
    opt.partitions = cosmo_nchoosek_partitioner(ds, ceil(nchunks / 2));
    opt.classifier = @cosmo_classify_nn;
    opt.output = 'fold_predictions';

    train_idx = opt.partitions.train_indices;
    test_idx = opt.partitions.test_indices;

    nfolds = numel(train_idx);
    nsamples = size(ds.samples, 1);

    % using crossvalidation_measure
    res = cosmo_crossvalidation_measure(ds, opt);

    % using crossvalidate function
    [cv_pred, acc] = cosmo_crossvalidate(ds, opt.classifier, opt.partitions);

    visited = false(size(res.samples));
    for k = 1:nfolds
        % test crossvalidation_measure
        msk = res.sa.folds == k;
        visited(msk) = true;
        pred = res.samples(msk, :);
        assertEqual(size(pred), [numel(test_idx{k}), 1]);
        assertEqual(res.sa.targets(msk), ds.sa.targets(test_idx{k}));

        % compare with classifier output
        fold_pred = opt.classifier(ds.samples(train_idx{k}, :), ...
                                   ds.sa.targets(train_idx{k}), ...
                                   ds.samples(test_idx{k}, :));
        assertEqual(fold_pred, pred);

        % check cosmo_crossvalidate output: column k is NaN except for the
        % test samples of fold k
        nan_msk = true(nsamples, 1);
        nan_msk(test_idx{k}) = false;
        assertEqual(isnan(cv_pred(:, k)), nan_msk);
        assertEqual(cv_pred(~nan_msk, k), fold_pred);
    end
    assert(all(visited));

    % fields should only be targets and folds
    assertEqual(sort(fieldnames(res.sa)), ...
                sort({'targets'; 'folds'}));

    % test accuracy: total number of correct predictions over total number
    % of predictions. Sum over all elements with (:) - dividing the two row
    % vectors returned by sum() would be an mrdivide (least-squares solve),
    % which only equals this ratio when all folds have equal test sizes.
    pred_msk = ~isnan(cv_pred);
    correct_pred = bsxfun(@eq, cv_pred, ds.sa.targets) & pred_msk;
    assertElementsAlmostEqual(acc, sum(correct_pred(:)) / sum(pred_msk(:)));

    % test with winner_predictions
    opt.output = 'winner_predictions';
    res = cosmo_crossvalidation_measure(ds, opt);
    assertEqual(size(res.samples), [nsamples, 1]);

    for row = 1:nsamples
        % how often each target was predicted for this sample across folds
        % (NaN entries, i.e. folds not testing this sample, are not counted)
        h = histc(cv_pred(row, :), (1:ntargets) + targets_offset);

        % the predicted target must be (one of) the most frequent
        row_pred = res.samples(row) - targets_offset;
        assert(h(row_pred) == max(h));
    end

function test_crossvalidation_measure_deprecations
    % deprecated output names must still be accepted, but each must emit
    % at least one warning
    orig_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));

    ds = cosmo_synthetic_dataset();
    opt = struct();
    opt.classifier = @cosmo_classify_nn;
    opt.partitions = cosmo_nfold_partitioner(ds);

    for output_cell = {'predictions', 'raw'}
        % clear the warning log and silence output before each run
        cosmo_warning('reset');
        cosmo_warning('off');

        opt.output = output_cell{1};
        cosmo_crossvalidation_measure(ds, opt);

        % the run must have recorded a warning
        state_after = cosmo_warning();
        assertTrue(numel(state_after.shown_warnings) >= 1, 'no warning was shown');
    end

function test_crossvalidation_measure_exceptions
    % invalid datasets or options must make the measure raise an error
    assert_raises = @(varargin)assertExceptionThrown(@() ...
                                   cosmo_crossvalidation_measure(varargin{:}), '');

    % empty struct as dataset, with meaningless partitions and classifier
    invalid_opt = struct();
    invalid_opt.partitions = struct();
    invalid_opt.classifier = @abs;
    assert_raises(struct, invalid_opt);

    ds = cosmo_synthetic_dataset();
    valid_opt = struct();
    valid_opt.partitions = cosmo_nfold_partitioner(ds);
    valid_opt.classifier = @cosmo_classify_lda;

    % valid options, but an empty struct as dataset
    assert_raises(struct, valid_opt);

    % unknown output name, and 'accuracy_by_chunk' which is not supported
    % anymore
    for bad_output = {'foo', 'accuracy_by_chunk'}
        invalid_opt = valid_opt;
        invalid_opt.output = bad_output{1};
        assert_raises(ds, invalid_opt);
    end

function test_balanced_accuracy()
    % 'balanced_accuracy' output must equal the mean over classes of the
    % per-class accuracy (computed manually here), on unbalanced data where
    % some samples are never tested. Also checks 'winner_predictions' and
    % the fields of the 'accuracy' and 'balanced_accuracy' outputs.
    nclasses = 10;
    nchunks = 20;
    ds = cosmo_synthetic_dataset('ntargets', nclasses, ...
                                 'nchunks', nchunks, ...
                                 'nreps', 4);

    % random data with randomly assigned targets and chunks; regenerate
    % until every class and every chunk occurs at least once and both the
    % class counts and the chunk counts are unbalanced
    ds.samples = randn(size(ds.samples));

    nsamples = size(ds.samples, 1);
    while true
        ds.sa.targets = ceil(rand(nsamples, 1) * nclasses);
        ds.sa.chunks = ceil(rand(nsamples, 1) * nchunks);

        h_t = histc(ds.sa.targets, 1:nclasses);
        h_c = histc(ds.sa.chunks, 1:nchunks);

        % histc returns one count per edge, so a missing class or chunk
        % shows up as a zero count rather than as a shorter vector
        if any(h_t == 0) || any(h_c == 0)
            % classes or chunks missing, regenerate
            continue
        end

        % imbalance: counts are not all identical
        if any(h_t ~= h_t(1)) && any(h_c ~= h_c(1))
            break
        end
    end

    % keep subset of all partitions, so that there are missing predictions
    % for some of the samples
    partitions = cosmo_nfold_partitioner(ds);
    nkeep = ceil(.3 * nchunks);
    partitions.train_indices = partitions.train_indices(1:nkeep);
    partitions.test_indices = partitions.test_indices(1:nkeep);

    % compute balanced accuracy
    opt = struct();
    opt.classifier = @cosmo_classify_nn;
    opt.partitions = partitions;

    % with partition checking enabled (the default), an exception should be
    % thrown because the partitions are unbalanced
    assertExceptionThrown(@() ...
                          cosmo_check_partitions(partitions, ds), '');
    assertExceptionThrown(@() ...
                          cosmo_crossvalidation_measure(ds, opt), '');
    opt.check_partitions = false;

    % compute accuracy
    opt.output = 'balanced_accuracy';
    ba_result = cosmo_crossvalidation_measure(ds, opt);

    opt.output = 'winner_predictions';
    pred_result = cosmo_crossvalidation_measure(ds, opt);

    opt.output = 'accuracy';
    acc_result = cosmo_crossvalidation_measure(ds, opt);

    % check fields
    result_cell = {ba_result, acc_result};
    for k = 1:numel(result_cell)
        result = result_cell{k};

        assertEqual(sort(fieldnames(result)), sort({'samples'; 'sa'}));
        assertEqual(fieldnames(result.sa), {'labels'});
    end

    assertEqual(ba_result.sa.labels, {'balanced_accuracy'});
    assertEqual(acc_result.sa.labels, {'accuracy'});

    % compute expected result for balanced accuracy
    [unused, unused, target_idx] = unique(ds.sa.targets);
    assert(max(target_idx) == nclasses);
    nfolds = numel(partitions.train_indices);

    correct_count = zeros(nfolds, nclasses);
    class_count = zeros(1, nclasses);

    % samples never tested keep a NaN prediction
    all_pred = NaN(nsamples, 1);

    for fold_i = 1:nfolds
        tr_idx = partitions.train_indices{fold_i};
        te_idx = partitions.test_indices{fold_i};

        ds_tr = cosmo_slice(ds, tr_idx);
        ds_te = cosmo_slice(ds, te_idx);

        target_idx_te = target_idx(te_idx);

        pred = opt.classifier(ds_tr.samples, ...
                              ds_tr.sa.targets, ...
                              ds_te.samples);
        all_pred(te_idx) = pred;
        for class_i = 1:nclasses
            target_msk = target_idx_te == class_i;
            is_correct = pred(target_msk) == ds_te.sa.targets(target_msk);
            correct_count(fold_i, class_i) = sum(is_correct);
            class_count(class_i) = class_count(class_i) + numel(is_correct);
        end
    end

    % per-class accuracy pooled over all folds
    class_acc = bsxfun(@rdivide, sum(correct_count, 1), class_count);

    % verify expected result for balanced accuracy
    assertElementsAlmostEqual(mean(class_acc), ba_result.samples);

    % verify expected result for predictions of each class
    assertEqual(pred_result.samples, all_pred);
    assertEqual(pred_result.sa.targets, ds.sa.targets);

function test_pca()
    % PCA preprocessing inside the measure (pca_explained_count and
    % pca_explained_ratio) must match a manual PCA fitted on the training
    % set, for several component counts and explained-variance ratios
    n_targets = 2;
    n_chunks = 5;

    n_features = ceil(rand() * 10 + 10);
    n_samples = n_targets * n_chunks * 4 * n_features;

    zero_based = (0:n_samples - 1)';

    ds = struct();
    ds.samples = randn(n_samples, n_features);
    ds.sa.targets = mod(zero_based, n_targets) + 1;
    ds.sa.chunks = mod(floor(zero_based / (n_targets * n_chunks)), n_chunks) + 1;

    % a single fold: the last chunk is the test set, all others train
    is_test = ds.sa.chunks == n_chunks;
    single_fold = struct();
    single_fold.train_indices = {find(~is_test)};
    single_fold.test_indices = {find(is_test)};

    base_opt = struct();
    base_opt.partitions = single_fold;
    base_opt.classifier = @cosmo_classify_lda;
    base_opt.output = 'winner_predictions';

    % fixed number of components
    counts = [1 ceil(n_features / 2) n_features ceil(rand() * n_features)];
    for i_count = 1:numel(counts)
        n_comp = counts(i_count);
        count_opt = base_opt;
        count_opt.pca_explained_count = n_comp;
        helper_test_pca_count(ds, count_opt, n_comp);
    end

    % fraction of explained variance
    ratios = [.1 .5 .9 1 rand()];
    for i_ratio = 1:numel(ratios)
        frac = ratios(i_ratio);
        ratio_opt = base_opt;
        ratio_opt.pca_explained_ratio = frac;
        helper_test_pca_ratio(ds, ratio_opt, frac);
    end

function helper_test_pca_count(ds, opt, count)
    % run the measure with the given options and check that its predictions
    % for the test samples equal a manual PCA + classifier computation that
    % keeps 'count' components
    measure_result = cosmo_crossvalidation_measure(ds, opt);

    [manual_pred, te_idx] = helper_pca_crossval_single_fold(ds, opt, count);
    measure_pred = measure_result.samples(te_idx);

    assertEqual(manual_pred, measure_pred);

function [pred, test_indices] = helper_pca_crossval_single_fold(ds, opt, count)
    % manual single-fold crossvalidation with PCA: fit PCA with 'count'
    % components on the training samples, project the test samples using
    % the training mean and coefficients, then classify.
    % Returns the predictions and the indices of the test samples.
    tr_cell = opt.partitions.train_indices;
    te_cell = opt.partitions.test_indices;
    assert(numel(tr_cell) == 1);
    assert(numel(te_cell) == 1);

    ds_train = cosmo_slice(ds, tr_cell{1});
    [train_proj, pca_params] = cosmo_pca(ds_train.samples, count);

    test_indices = te_cell{1};
    ds_test = cosmo_slice(ds, test_indices);
    centered_test = bsxfun(@minus, ds_test.samples, pca_params.mu);
    test_proj = centered_test * pca_params.coef;

    pred = opt.classifier(train_proj, ds_train.sa.targets, test_proj);

function helper_test_pca_ratio(ds, opt, ratio)
    % translate an explained-variance ratio into the smallest number of
    % components whose cumulative explained variance (in percent) reaches
    % ratio*100, then check the measure as for a fixed component count
    train_cell = opt.partitions.train_indices;
    assert(numel(train_cell) == 1);
    ds_train = cosmo_slice(ds, train_cell{1});
    [unused, pca_params] = cosmo_pca(ds_train.samples);

    cum_explained = cumsum(pca_params.explained);
    count = find(cum_explained >= ratio * 100, 1);
    if isempty(count)
        % threshold never reached (presumably floating-point rounding when
        % ratio is 1): keep all components
        count = numel(pca_params.explained);
    end

    % the remaining check is identical to the fixed-count case
    helper_test_pca_count(ds, opt, count);

function test_crossvalidation_measure_pca_exceptions
    % pca_explained_count and pca_explained_ratio are mutually exclusive;
    % setting both must raise an error
    ds = cosmo_synthetic_dataset();

    conflicting_opt = struct();
    conflicting_opt.classifier = @cosmo_classify_lda;
    conflicting_opt.partitions = cosmo_nfold_partitioner(ds);
    conflicting_opt.pca_explained_count = 2;
    conflicting_opt.pca_explained_ratio = .5;

    assertExceptionThrown(@() ...
                          cosmo_crossvalidation_measure(ds, conflicting_opt), '');