test classify

function test_suite = test_classify
    % tests for cosmo_classify_* functions
    %
    % Entry point for the unit-test runner. initTestSuite is expected to
    % populate test_suite with the test_* local functions of this file.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % test_functions looks unused here, but presumably initTestSuite
        % (run in this workspace) reads it -- do not remove.
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_general_classifiers_strong_signal()
    % Every classifier must perform at chance on uninformative data and
    % near-perfectly on strongly informative data. A classifier whose
    % external dependency is missing must raise an exception instead.
    %
    % note: part of these checks are also implemented in
    % general_test_classifier (included below)

    if cosmo_wtf('is_matlab')
        csvm_external = '!fitcsvm';
    else
        % '!' names an external that never exists, so on Octave the
        % matlabcsvm classifier is expected to throw
        non_existing_external = '!';
        csvm_external = non_existing_external;
    end

    % each entry: {classifier, required external ('' if none)}
    cfy_cell = {{@cosmo_classify_lda, ''}, ...
                {@cosmo_classify_nn, ''}, ...
                {@cosmo_classify_naive_bayes, ''}, ...
                {@cosmo_classify_libsvm, 'libsvm'}, ...
                {@cosmo_classify_matlabsvm, 'matlabsvm'}, ...
                {@cosmo_classify_matlabcsvm, csvm_external}};

    n_cfy = numel(cfy_cell);

    no_information = 0;
    strong_information = 2 + rand();
    for sigma = [no_information, strong_information]
        [tr_s, tr_t, te_s, te_t] = generate_informative_data(sigma);

        for k = 1:n_cfy
            cfy = cfy_cell{k}{1};
            external = cfy_cell{k}{2};
            predictor_func = @()cfy(tr_s, tr_t, te_s);

            % missing external dependency: the call itself must fail
            if ~strcmp(external, '') && ~cosmo_check_external(external, false)
                assertExceptionThrown(predictor_func, '*');
                continue
            end

            pred = predictor_func();
            acc = mean(pred == te_t);

            % exactly one prediction per test sample; derived from the data
            % rather than hard-coding the sample count of the generator
            assertEqual(numel(pred), numel(te_t));

            if sigma == no_information
                assertTrue(0.4 < acc && acc < 0.6);
            elseif sigma == strong_information
                assertTrue(acc > 0.9);
            else
                assertFalse(true);
            end
        end
    end

function test_classify_lda
    % LDA classifier: fixed predictions on synthetic data, the generic
    % classifier checks, and rejection of malformed or oversized input.
    cfy = @cosmo_classify_lda;
    handle = get_predictor(cfy);
    assert_predictions_equal(handle, [1 3 9 8 5 6 8 9 7 5 7 5 4 ...
                                      9 2 7 7 7 1 2 1 1 7 6 7 1 7]');
    general_test_classifier(cfy);

    aet = @(varargin)assertExceptionThrown(@() ...
                                           cosmo_classify_lda(varargin{:}), '');
    % size mismatch
    x = randn(5, 3);
    y = randn(2, 3);
    aet(x, [1 1 1 2 2], y);

    % too many features: training and test data must have the same
    % (large) number of features. Otherwise the error raised would be a
    % train/test feature-count mismatch, not the one this check is about.
    x = zeros(1, 1e4);
    aet(x, 1, x);

function test_classify_naive_bayes
    % naive Bayes classifier: fixed predictions plus the generic checks
    classifier = @cosmo_classify_naive_bayes;
    expected = [1 7 3 9 2 2 8 9 7 4 7 2 4 ...
                8 2 7 7 7 1 2 7 1 7 2 7 1 9]';
    assert_predictions_equal(get_predictor(classifier), expected);
    general_test_classifier(classifier);

function test_classify_meta_feature_selection
    % Meta classifier: ANOVA-based feature selection (ratio to keep 0.6)
    % followed by a nearest-neighbour child classifier.
    %
    % History (Oct 2016): a Travis CI run on Octave crashed at exit with
    %
    %     ** glibc detected *** /usr/bin/octave-cli: double free or corruption (!prev): 0x0000000003570030 ***
    %     ======= Backtrace: =========
    %     /lib/x86_64-linux-gnu/libc.so.6(+0x7da26)[0x2abf370a4a26]
    %     /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(_ZN12symbol_table7cleanupEv+0x7d)[0x2abf3678001d]
    %     /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(+0xb486de)[0x2abf367b26de]
    %     /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(_Z17clean_up_and_exitib+0x17)[0x2abf367b4057]
    %     /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(octave_execute_interpreter+0xd1a)[0x2abf35f341ca]
    %     /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xed)[0x2abf370487ed]
    %
    %     https://travis-ci.org/nno/CoSMoMVPA/builds/166411636
    %
    % Using cosmo_classify_nn instead of cosmo_classify_lda as the child
    % classifier seemed to fix it, but that may only address a symptom,
    % not the root cause.
    opt = struct('child_classifier', @cosmo_classify_nn, ...
                 'feature_selector', @cosmo_anova_feature_selector, ...
                 'feature_selection_ratio_to_keep', .6);

    meta_cfy = @cosmo_classify_meta_feature_selection;
    expected = [1 1 7 1 6 6 1 1 7 5 1 1 4 ...
                9 6 7 7 1 1 2 1 3 9 6 7 8 7]';
    assert_predictions_equal(get_predictor(meta_cfy, opt), expected);
    general_test_classifier(meta_cfy, opt);

function test_cosmo_meta_feature_selection_classifier
    % Deprecated meta feature-selection classifier; it emits a warning, so
    % warnings are silenced here and restored when the test exits. Its
    % expected predictions match those of
    % cosmo_classify_meta_feature_selection with the same options.
    old_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(old_warning_state));
    cosmo_warning('off');

    opt = struct('child_classifier', @cosmo_classify_nn, ...
                 'feature_selector', @cosmo_anova_feature_selector, ...
                 'feature_selection_ratio_to_keep', .6);

    legacy_cfy = @cosmo_meta_feature_selection_classifier;
    expected = [1 1 7 1 6 6 1 1 7 5 1 1 4 ...
                9 6 7 7 1 1 2 1 3 9 6 7 8 7]';
    assert_predictions_equal(get_predictor(legacy_cfy, opt), expected);
    general_test_classifier(legacy_cfy, opt);

function test_classify_nn
    % nearest-neighbour classifier: fixed predictions plus generic checks
    nn_cfy = @cosmo_classify_nn;
    expected = [1 3 6 8 6 6 8 7 7 5 7 7 4 ...
                9 7 7 7 7 1 2 7 1 5 6 7 1 9]';
    assert_predictions_equal(get_predictor(nn_cfy), expected);
    general_test_classifier(nn_cfy);

function test_classify_knn
    % k-nearest-neighbour classifier with knn = 2
    knn_opt = struct('knn', 2);
    classifier = @cosmo_classify_knn;

    expected = [7 3 6 1 7 2 8 7 9 4 7 5 7 1 ...
                7 7 7 8 1 6 1 1 9 5 8 1 9]';
    assert_predictions_equal(get_predictor(classifier, knn_opt), expected);
    general_test_classifier(classifier, knn_opt);

function test_classify_matlabsvm
    % Matlab SVM classifier. If the 'matlabsvm' external is unavailable,
    % only check that calls fail, then report the test as skipped.
    % Warnings are silenced for the duration of the test.
    saved_warnings = cosmo_warning();
    restorer = onCleanup(@()cosmo_warning(saved_warnings));
    cosmo_warning('off');

    classifier = @cosmo_classify_matlabsvm;
    predictor = get_predictor(classifier);

    if cosmo_check_external('matlabsvm', false)
        expected = [1 3 9 7 6 6 9 3 7 5 6 6 4 ...
                    1 7 7 7 7 1 7 7 1 7 6 7 1 9]';
        assert_predictions_equal(predictor, expected);
        general_test_classifier(classifier);
    else
        assert_throws_illegal_input_exceptions(classifier);
        assertExceptionThrown(predictor, '');
        notify_test_skipped('matlabsvm');
    end

function test_classify_matlabcsvm
    % Matlab CSVM classifier. It requires both the 'matlabcsvm' and
    % '!fitcsvm' externals; otherwise only failure is checked and the test
    % is reported as skipped. Warnings are silenced for this test.
    saved_warnings = cosmo_warning();
    restorer = onCleanup(@()cosmo_warning(saved_warnings));
    cosmo_warning('off');

    classifier = @cosmo_classify_matlabcsvm;
    predictor = get_predictor(classifier);
    has_csvm = cosmo_check_external('matlabcsvm', false) && ...
               cosmo_check_external('!fitcsvm', false);

    if has_csvm
        expected = [1 2 3 4 5 6 7 8 9 1 2 3 4 ...
                    5 6 7 8 9 1 2 3 4 5 6 7 8 9]';
        assert_predictions_equal(predictor, expected);
        general_test_classifier(classifier);
    else
        assert_throws_illegal_input_exceptions(classifier);
        % NOTE(review): unlike the other tests, no expected identifier is
        % passed here; confirm the one-argument form is supported by the
        % test framework in use
        assertExceptionThrown(predictor);
        notify_test_skipped('matlabcsvm');
    end

function test_classify_matlabsvm_2class
    % Two-class Matlab SVM. It must reject nine-class data, give fixed
    % predictions on two-class data, and propagate training errors.
    % Warnings are silenced for the duration of the test.
    saved_warnings = cosmo_warning();
    restorer = onCleanup(@()cosmo_warning(saved_warnings));
    cosmo_warning('off');

    classifier = @cosmo_classify_matlabsvm_2class;
    nine_class_predictor = get_predictor(classifier);

    if ~cosmo_check_external('matlabsvm', false)
        assert_throws_illegal_input_exceptions(classifier);
        assertExceptionThrown(nine_class_predictor, '');
        notify_test_skipped('matlabsvm');
        return
    end

    % cannot deal with nine classes
    assertExceptionThrown(nine_class_predictor, '');

    two_class_predictor = get_predictor(classifier, struct(), 2);
    assert_predictions_equal(two_class_predictor, [1 2 2 2 1 2]');
    general_test_classifier(classifier);

    % training must fail: first through non-convergence (MaxIter = 1),
    % then through an invalid tolkkt option
    xor_samples = [0 0; 0 1; 1 0; 1 1];
    xor_targets = [1 2 2 1];
    test_samples = NaN(2);

    svm_opt = struct();
    svm_opt.options.MaxIter = 1;
    assertExceptionThrown(@()cosmo_classify_matlabsvm_2class( ...
                              xor_samples, xor_targets, test_samples, svm_opt), '');

    svm_opt.tolkkt = struct();
    assertExceptionThrown(@()cosmo_classify_matlabsvm_2class( ...
                              xor_samples, xor_targets, test_samples, svm_opt), ...
                          'stats:svmtrain:badTolKKT');

function test_classify_libsvm_with_autoscale
    % libsvm with autoscale enabled; skipped when libsvm is unavailable.
    % Warnings are silenced for the duration of the test.
    saved_warnings = cosmo_warning();
    restorer = onCleanup(@()cosmo_warning(saved_warnings));
    cosmo_warning('off');

    classifier = @cosmo_classify_libsvm;
    autoscale_opt = struct('autoscale', true);
    predictor = get_predictor(classifier, autoscale_opt);

    if cosmo_check_external('libsvm', false)
        expected = [8 3 3 8 6 6 1 3 7 5 7 6 4 ...
                    1 2 7 7 7 1 2 8 1 9 6 7 1 3]';
        assert_predictions_equal(predictor, expected);
        general_test_classifier(classifier, autoscale_opt);
    else
        assertExceptionThrown(predictor, '');
        notify_test_skipped('libsvm');
    end

function test_classify_libsvm_no_autoscale
    % libsvm with autoscale disabled; skipped when libsvm is unavailable.
    % Warnings are silenced for the duration of the test.
    warning_state = cosmo_warning();
    cleaner = onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');

    cfy = @cosmo_classify_libsvm;
    opt = struct();
    opt.autoscale = false;
    handle = get_predictor(cfy, opt);
    if ~cosmo_check_external('libsvm', false)
        assertExceptionThrown(handle, '');
        % use the local skip helper, as all other tests in this file do
        notify_test_skipped('libsvm');
        return
    end

    assert_predictions_equal(handle, [8 3 3 8 6 6 1 3 7 5 7 6 4 ...
                                      1 2 7 7 7 1 2 8 1 9 6 7 1 3]');
    general_test_classifier(cfy, opt);

function test_classify_libsvm_t0
    % libsvm with the default (linear kernel) type t=0. Omitting the
    % options, or spelling out some default options in different forms,
    % must all give the same result.
    cfy = @cosmo_classify_libsvm;
    expected = [8 3 3 8 6 6 1 3 7 5 7 6 4 ...
                1 2 7 7 7 1 2 8 1 9 6 7 1 3]';

    option_variants = {{}, ...
                       {'t', '0'}, ...
                       {'t', 0}, ...
                       {'t', 0, 'autoscale', false, 's', '0', 'r', 0, 'c', 1, 'h', 1}};

    for variant_idx = 1:numel(option_variants)
        variant = option_variants{variant_idx};

        % an empty variant means: call the classifier without options
        opt = [];
        opt_struct = struct();
        if ~isempty(variant)
            opt = cosmo_structjoin(variant);
            opt_struct = opt;
        end

        handle = get_predictor(cfy, opt);
        if ~cosmo_check_external('libsvm', false)
            assertExceptionThrown(handle, '');
            notify_test_skipped('libsvm');
            return
        end

        assert_predictions_equal(handle, expected);
        general_test_classifier(cfy, opt_struct);
    end

function test_classify_libsvm_t2
    % libsvm with kernel type t = 2; skipped when libsvm is unavailable
    classifier = @cosmo_classify_libsvm;
    kernel_opt = struct('t', 2);
    predictor = get_predictor(classifier, kernel_opt);

    if cosmo_check_external('libsvm', false)
        % libsvm uses autoscale by default
        expected = [1 3 6 8 6 6 8 3 7 5 7 5 4 ...
                    5 7 7 7 8 1 2 8 1 5 5 7 1 7]';
        assert_predictions_equal(predictor, expected);
        general_test_classifier(classifier, kernel_opt);
    else
        assertExceptionThrown(predictor, '');
        notify_test_skipped('libsvm');
    end

function test_classify_svm
    % Generic SVM classifier. It gives libsvm predictions when libsvm is
    % available and Matlab SVM predictions otherwise. Running the generic
    % checks with the other backend requested (bad_opt) must throw.

    % Clear cosmo_check_external from memory, presumably to reset any
    % state it caches. Note: `clear cosmo_check_external()` is command
    % syntax, which passes the literal text 'cosmo_check_external()' to
    % clear; that text matches nothing, so nothing would be cleared.
    clear('cosmo_check_external');

    cfy = @cosmo_classify_svm;
    handle = get_predictor(cfy);
    if ~cosmo_check_external('svm', false)
        assertExceptionThrown(handle, '');
        notify_test_skipped('svm');
        return
    end

    % matlab and libsvm show slightly different results
    if cosmo_check_external('libsvm', false)
        pred = [8 3 3 8 6 6 1 3 7 5 7 6 4 1 2 7 7 7 1 2 8 1 9 6 7 1 3]';

        good_opt = struct();
        good_opt.svm = 'libsvm';
        bad_opt = struct();
        bad_opt.svm = 'matlabsvm';
    else
        % do not show warning message
        warning_state = cosmo_warning();
        cleaner = onCleanup(@()cosmo_warning(warning_state));
        cosmo_warning('off');

        pred = [1 3 9 7 6 6 9 3 7 5 6 6 4 1 7 7 7 7 1 7 7 1 7 6 7 1 9]';

        good_opt = struct();
        good_opt.svm = 'matlabsvm';
        bad_opt = struct();
        bad_opt.svm = 'libsvm';
    end

    assert_predictions_equal(handle, pred);
    general_test_classifier(cfy);
    general_test_classifier(cfy, good_opt);
    assertExceptionThrown(@()general_test_classifier(cfy, bad_opt), '');

function general_test_classifier(cfy_base, opt)
    % Run the generic classifier checks on cfy_base. The checks are:
    % chance accuracy on null data, high accuracy on informative data,
    % and error handling. When opt is given, it is bound as the fourth
    % argument of every classifier call.
    if nargin >= 2
        cfy = @(x, y, z)cfy_base(x, y, z, opt);
    else
        cfy = cfy_base;
    end

    assert_chance_null_data(cfy);
    assert_above_chance_informative_data(cfy);
    assert_throws_expected_exceptions(cfy_base, cfy);

function assert_chance_null_data(cfy)
    % With no signal (sigma = 0), accuracy must lie within [0.3, 0.7].
    signal_strength = 0;
    assert_accuracy_in_range(cfy, signal_strength, 0.3, 0.7);

function assert_above_chance_informative_data(cfy)
    % With a strong signal (sigma = 4), accuracy must lie within [0.8, 1].
    signal_strength = 4;
    assert_accuracy_in_range(cfy, signal_strength, 0.8, 1);

function assert_accuracy_in_range(cfy, sigma, min_val, max_val)
    % Train cfy on freshly generated data with signal strength sigma and
    % assert that test accuracy lies in the closed range [min_val, max_val].
    [train_samples, train_targets, test_samples, test_targets] = ...
        generate_informative_data(sigma);

    predicted = cfy(train_samples, train_targets, test_samples);
    is_correct = predicted == test_targets;
    accuracy = mean(is_correct);

    assertTrue(accuracy >= min_val);
    assertTrue(accuracy <= max_val);

function [tr_s, tr_t, te_s, te_t] = generate_informative_data(sigma)
    % Generate train and test sets of 400 samples x 10 features with
    % alternating targets 1, 2, 1, 2, ... Samples are standard-normal
    % noise. Feature k additionally carries a random offset (scaled by
    % sigma) for the samples of class mod(k-1, 2)+1 only. The same
    % offsets are used for train and test data.
    n_classes = 2;
    n_per_class = 200;
    n_features = 10;
    n_samples = n_classes * n_per_class;

    % random draws happen in this order: offsets, train noise, test noise
    offsets = randn(1, n_features) * sigma;
    labels = repmat((1:n_classes)', n_per_class, 1);

    tr_s = randn(n_samples, n_features);
    te_s = randn(n_samples, n_features);

    % mask(i, k) is true when sample i belongs to the class that feature k
    % is informative for
    feature_class = mod(0:(n_features - 1), n_classes) + 1;
    mask = bsxfun(@eq, labels, feature_class);
    signal = bsxfun(@times, mask, offsets);

    tr_s = tr_s + signal;
    te_s = te_s + signal;

    tr_t = labels;
    te_t = labels;

function assert_throws_expected_exceptions(cfy_base, cfy)
    % Malformed inputs must raise an error, and degenerate-but-valid
    % inputs must be handled. cfy_base identifies the classifier; cfy is
    % the callable actually invoked (possibly with options bound).
    assert_throws_illegal_input_exceptions(cfy);
    assert_deals_with_empty_input(cfy_base, cfy);

function assert_throws_illegal_input_exceptions(cfy)
    % Every {train samples, train targets, test samples} triple below has
    % inconsistent sizes; cfy must raise an error for each of them.
    % Warnings are silenced while the checks run.
    saved_warnings = cosmo_warning();
    restorer = onCleanup(@()cosmo_warning(saved_warnings));
    cosmo_warning('off');

    bad_inputs = {{[1 2], [1; 2], [1 2]}, ...
                  {[1; 2], [1 2], [1 2]}, ...
                  {[1 2], [1 2], [1 2]}, ...
                  {[1 2], 1, [1; 2]}, ...
                  {[1; 2], 1, [1 2]}, ...
                  {[1 2; 3 4; 5 6], [1; 1], [1 2]}, ...
                  {[1 2; 3 4; 5 6], [1; 1; 1], [1 2 3]}};

    for idx = 1:numel(bad_inputs)
        args = bad_inputs{idx};
        assertExceptionThrown(@()cfy(args{:}), '');
    end

function assert_deals_with_empty_input(cfy_base, cfy)
    % Degenerate but valid inputs must not raise:
    %  - training data with a single class (checked only for classifiers
    %    not listed in needs_two_classes below)
    %  - data with zero features, for which predictions must still be
    %    valid labels and identical across repeated calls
    needs_two_classes = {@cosmo_classify_matlabsvm_2class, ...
                         @cosmo_classify_meta_feature_selection, ...
                         @cosmo_meta_feature_selection_classifier, ...
                         @cosmo_classify_matlabcsvm};

    % presumably the libsvm from Octave's statistics package cannot train
    % on a single class either
    if cosmo_check_external('octave_pkg_statistics_libsvm', false)
        needs_two_classes{end + 1} = @cosmo_classify_libsvm;
        needs_two_classes{end + 1} = @cosmo_classify_svm;
    end

    is_listed = false;
    for idx = 1:numel(needs_two_classes)
        if isequal(cfy_base, needs_two_classes{idx})
            is_listed = true;
            break
        end
    end

    if ~is_listed
        cfy([1 2; 3 4], [1; 1], [1 2]);
        cfy([1 2; 3 4; 5 6], [1; 1; 1], [1 2]);
    end

    % no features: a prediction must still be made for each test sample
    no_feature_train = zeros(4, 0);
    no_feature_test = zeros(2, 0);
    two_class_targets = [1 1 2 2]';

    first = cfy(no_feature_train, two_class_targets, no_feature_test);
    assertEqual(size(first), [2 1]);
    assertTrue(all(first == 1 | first == 2));

    second = cfy(no_feature_train, two_class_targets, no_feature_test);
    assertEqual(first, second);

function handle = get_predictor(cfy, opt, nclasses)
    % Return a zero-argument function handle. The handle trains cfy on the
    % training split of a synthetic dataset with nclasses targets
    % (default 9) and predicts its test split. A non-empty opt is passed
    % as fourth argument; an omitted or empty opt means no options.
    if nargin < 3
        nclasses = 9;
    end

    opt_arg = {};
    if nargin >= 2 && ~isempty(opt)
        opt_arg = {opt};
    end

    [tr_samples, tr_targets, te_samples] = generate_data(nclasses);
    handle = @()cfy(tr_samples, tr_targets, te_samples, opt_arg{:});

function assert_predictions_equal(handle, targets)
    % Calling handle must return exactly targets. A second call, which may
    % be served from a classifier's cache, must return the same result.
    first_pred = handle();
    assertEqual(first_pred, targets);

    second_pred = handle();
    assertEqual(first_pred, second_pred);

function [tr_samples, tr_targets, te_samples] = generate_data(nclasses)
    % Split a synthetic dataset with nclasses targets and 6 chunks. Samples
    % in chunks <= 3 form the test set; all other samples form the
    % training set.
    ds = cosmo_synthetic_dataset('ntargets', nclasses, 'nchunks', 6);
    is_test = ds.sa.chunks <= 3;

    te_samples = ds.samples(is_test, :);
    tr_samples = ds.samples(~is_test, :);
    tr_targets = ds.sa.targets(~is_test);

function notify_test_skipped(external)
    % Report the current test as skipped because the given external is
    % missing; cosmo_skip_test_if_no_external must return true here.
    skipped = cosmo_skip_test_if_no_external(external);
    assertTrue(skipped);