% test classify

function test_suite = test_classify
% tests for cosmo_classify_* functions
%
% Entry point of this test file; the individual tests are the local
% test_* functions defined below.
%
% Output:
%   test_suite      not assigned explicitly in this function; presumably
%                   set by the initTestSuite script (TODO confirm against
%                   the unit test framework in use)
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % the variable is not used in this function itself; it is
        % presumably read by initTestSuite, so it must keep this name
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % invoked without arguments or outputs, so it operates on this
    % function's workspace
    initTestSuite;

function test_general_classifiers_strong_signal()
% Checks accuracy of several classifiers on synthetic two-class data,
% once without any signal (accuracy must be near chance) and once with
% strong signal (accuracy must exceed 0.9).
%
% Each classifier is paired with the name of the external it requires
% (empty if none). If a required external is not available, the
% classifier is only required to throw an exception.
%
% note: part of these checks are also implemented in
% general_test_classifier (included below)

    % keep test output free of warnings, as in the other tests in this
    % file; the previous warning state is restored when this function exits
    warning_state=cosmo_warning();
    warning_resetter=onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');

    cfy_cell={{@cosmo_classify_lda,''},...
              {@cosmo_classify_nn,''},...
              {@cosmo_classify_naive_bayes,''},...
              {@cosmo_classify_libsvm,'libsvm'},...
              {@cosmo_classify_matlabsvm,'matlabsvm'},...
              {@cosmo_classify_matlabcsvm,'matlabcsvm'}};

    n_cfy=numel(cfy_cell);

    no_information=0;
    strong_information=2+rand(); % always >=2, so never equal to zero
    for sigma=[no_information, strong_information]
        [tr_s,tr_t, te_s, te_t]=generate_informative_data(sigma);
        has_signal=sigma~=no_information;

        for k=1:n_cfy
            cfy=cfy_cell{k}{1};
            external=cfy_cell{k}{2};
            predictor_func=@()cfy(tr_s, tr_t, te_s);

            needs_external=~isempty(external);
            if needs_external && ~cosmo_check_external(external,false)
                % external is missing: any exception is acceptable
                assertExceptionThrown(predictor_func,'*');
                continue;
            end

            pred=predictor_func();

            % one prediction for each test sample; checked before
            % computing the accuracy so that a size problem is reported
            % by this assertion
            assertEqual(numel(pred), numel(te_t));

            acc=mean(pred==te_t);

            if has_signal
                assertTrue(acc > 0.9);
            else
                % two classes, so chance level is 0.5
                assertTrue(0.4 < acc && acc < 0.6);
            end
        end
    end


function test_classify_lda
% Regression test for cosmo_classify_lda: compares predictions on the
% data from get_predictor with stored labels, runs the generic classifier
% checks, and verifies that illegal inputs raise an exception.
    cfy=@cosmo_classify_lda;
    handle=get_predictor(cfy);
    assert_predictions_equal(handle,[1 3 9 8 5 6 8 9 7 5 7 5 4 ...
                                    9 2 7 7 7 1 2 1 1 7 6 7 1 7 ]');
    general_test_classifier(cfy);

    % helper: calling cosmo_classify_lda with the given arguments must
    % throw an exception with an empty identifier
    aet=@(varargin)assertExceptionThrown(@()...
                        cosmo_classify_lda(varargin{:}),'');
    % size mismatch
    x=randn(5,3);
    y=randn(2,3);
    aet(x,[1 1 1 2 2],y);

    % too many features
    % the test samples have the same number of features as the training
    % samples, so that the exception cannot be attributed to a mismatch in
    % the number of features (as would be possible with a scalar test
    % sample)
    x=zeros(1,1e4);
    aet(x,1,x);


function test_classify_naive_bayes
% Regression test for cosmo_classify_naive_bayes: predictions on the data
% from get_predictor must match the stored labels, and the generic
% classifier checks must pass.
    classifier=@cosmo_classify_naive_bayes;
    expected_pred=[1 7 3 9 2 2 8 9 7 4 7 2 4 ...
                   8 2 7 7 7 1 2 7 1 7 2 7 1 9 ]';

    predictor=get_predictor(classifier);
    assert_predictions_equal(predictor,expected_pred);

    general_test_classifier(classifier);


function test_classify_meta_feature_selection
% Regression test for cosmo_classify_meta_feature_selection, using
% nearest-neighbor classification on the features kept by the ANOVA
% feature selector.
%
%     % Oct 2016: got strange glibc error when using travis
%
%     ** glibc detected *** /usr/bin/octave-cli: double free or corruption (!prev): 0x0000000003570030 ***
%     ======= Backtrace: =========
%     /lib/x86_64-linux-gnu/libc.so.6(+0x7da26)[0x2abf370a4a26]
%     /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(_ZN12symbol_table7cleanupEv+0x7d)[0x2abf3678001d]
%     /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(+0xb486de)[0x2abf367b26de]
%     /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(_Z17clean_up_and_exitib+0x17)[0x2abf367b4057]
%     /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(octave_execute_interpreter+0xd1a)[0x2abf35f341ca]
%     /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xed)[0x2abf370487ed]
%
%     https://travis-ci.org/nno/CoSMoMVPA/builds/166411636
%
%     The fix seemed to be to use cosmo_classify_nn instead of
%     cosmo_classify_lda as child_classifier - but that may just be
%     describing a symptom, not the root cause.
    classifier=@cosmo_classify_meta_feature_selection;

    % function handle values are stored as-is, giving a scalar struct
    opt=struct('child_classifier',@cosmo_classify_nn,...
               'feature_selector',@cosmo_anova_feature_selector,...
               'feature_selection_ratio_to_keep',.6);

    expected_pred=[1 1 7 1 6 6 1 1 7 5 1 1 4 ...
                   9 6 7 7 1 1 2 1 3 9 6 7 8 7 ]';

    predictor=get_predictor(classifier,opt);
    assert_predictions_equal(predictor,expected_pred);

    general_test_classifier(classifier,opt);

function test_cosmo_meta_feature_selection_classifier
% Regression test for cosmo_meta_feature_selection_classifier, with the
% same options and expected predictions as used in
% test_classify_meta_feature_selection.

    % the function under test is deprecated and shows a warning; silence
    % warnings here and restore the previous state on exit
    orig_warning_state=cosmo_warning();
    warning_resetter=onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier=@cosmo_meta_feature_selection_classifier;

    opt=struct('child_classifier',@cosmo_classify_nn,...
               'feature_selector',@cosmo_anova_feature_selector,...
               'feature_selection_ratio_to_keep',.6);

    expected_pred=[1 1 7 1 6 6 1 1 7 5 1 1 4 ...
                   9 6 7 7 1 1 2 1 3 9 6 7 8 7 ]';

    predictor=get_predictor(classifier,opt);
    assert_predictions_equal(predictor,expected_pred);

    general_test_classifier(classifier,opt);

function test_classify_nn
% Regression test for cosmo_classify_nn: predictions on the data from
% get_predictor must match the stored labels, and the generic classifier
% checks must pass.
    classifier=@cosmo_classify_nn;
    expected_pred=[1 3 6 8 6 6 8 7 7 5 7 7 4 ...
                   9 7 7 7 7 1 2 7 1 5 6 7 1 9]';

    predictor=get_predictor(classifier);
    assert_predictions_equal(predictor,expected_pred);

    general_test_classifier(classifier);

function test_classify_knn
% Regression test for cosmo_classify_knn with the knn option set to 2:
% predictions on the data from get_predictor must match the stored labels,
% and the generic classifier checks must pass.
    classifier=@cosmo_classify_knn;
    opt=struct('knn',2);

    expected_pred=[7 3 6 1 7 2 8 7 9 4 7 5 7 1 ...
                   7 7 7 8 1 6 1 1 9 5 8 1 9]';

    predictor=get_predictor(classifier,opt);
    assert_predictions_equal(predictor,expected_pred);

    general_test_classifier(classifier,opt);

function test_classify_matlabsvm
% Regression test for cosmo_classify_matlabsvm. If the 'matlabsvm'
% external is not available, the classifier must throw exceptions and the
% test is reported as skipped.

    % silence warnings; the previous state is restored on exit
    orig_warning_state=cosmo_warning();
    warning_resetter=onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier=@cosmo_classify_matlabsvm;
    predictor=get_predictor(classifier);

    if cosmo_check_external('matlabsvm',false)
        expected_pred=[1 3 9 7 6 6 9 3 7 5 6 6 4 ...
                       1 7 7 7 7 1 7 7 1 7 6 7 1 9]';
        assert_predictions_equal(predictor,expected_pred);
        general_test_classifier(classifier);
    else
        % external missing: only exceptions are expected
        assert_throws_illegal_input_exceptions(classifier);
        assertExceptionThrown(predictor,'');
        notify_test_skipped('matlabsvm');
    end

function test_classify_matlabcsvm
% Regression test for cosmo_classify_matlabcsvm. If the 'matlabcsvm'
% external is not available, the classifier must throw exceptions and the
% test is reported as skipped.

    % silence warnings; the previous state is restored on exit
    orig_warning_state=cosmo_warning();
    warning_resetter=onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier=@cosmo_classify_matlabcsvm;
    predictor=get_predictor(classifier);

    if cosmo_check_external('matlabcsvm',false)
        expected_pred=[1 2 3 4 5 6 7 8 9 1 2 3 4 ...
                       5 6 7 8 9 1 2 3 4 5 6 7 8 9]';
        assert_predictions_equal(predictor,expected_pred);
        general_test_classifier(classifier);
    else
        % external missing: only exceptions are expected
        assert_throws_illegal_input_exceptions(classifier);
        assertExceptionThrown(predictor,'');
        notify_test_skipped('matlabcsvm');
    end

function test_classify_matlabsvm_2class
% Tests cosmo_classify_matlabsvm_2class: rejects nine-class data, gives
% the stored predictions for two-class data, passes the generic
% classifier checks, and throws for the option sets tried at the end.
% If the 'matlabsvm' external is not available, only exceptions are
% expected and the test is reported as skipped.

    % silence warnings; the previous state is restored on exit
    orig_warning_state=cosmo_warning();
    warning_resetter=onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier=@cosmo_classify_matlabsvm_2class;

    % default data from get_predictor has nine classes
    nine_class_predictor=get_predictor(classifier);

    if ~cosmo_check_external('matlabsvm',false)
        assert_throws_illegal_input_exceptions(classifier);
        assertExceptionThrown(nine_class_predictor,'');
        notify_test_skipped('matlabsvm');
        return;
    end

    % cannot deal with nine classes
    assertExceptionThrown(nine_class_predictor,'');

    two_class_predictor=get_predictor(classifier,struct(),2);
    assert_predictions_equal(two_class_predictor,[1 2 2 2 1 2]');
    general_test_classifier(classifier);

    % test non-convergence, using an XOR-like target layout with a single
    % iteration allowed
    train_samples=[0 0; 0 1; 1 0; 1 1];
    train_targets=[1 2 2 1];
    test_samples=NaN(2);

    opt=struct();
    opt.options.MaxIter=1;
    assertExceptionThrown(@()cosmo_classify_matlabsvm_2class(...
                            train_samples,train_targets,...
                            test_samples,opt),'');

    % a struct is not a valid value for tolkkt
    opt.tolkkt=struct();
    assertExceptionThrown(@()cosmo_classify_matlabsvm_2class(...
                            train_samples,train_targets,...
                            test_samples,opt),...
                            'stats:svmtrain:badTolKKT');



function test_classify_libsvm_with_autoscale
% Regression test for cosmo_classify_libsvm with autoscale enabled. If the
% 'libsvm' external is not available, the classifier must throw an
% exception and the test is reported as skipped.

    % silence warnings; the previous state is restored on exit
    orig_warning_state=cosmo_warning();
    warning_resetter=onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier=@cosmo_classify_libsvm;
    opt=struct('autoscale',true);
    predictor=get_predictor(classifier,opt);

    if cosmo_check_external('libsvm',false)
        expected_pred=[8 3 3 8 6 6 1 3 7 5 7 6 4 ...
                       1 2 7 7 7 1 2 8 1 9 6 7 1 3 ]';
        assert_predictions_equal(predictor,expected_pred);
        general_test_classifier(classifier,opt);
    else
        assertExceptionThrown(predictor,'');
        notify_test_skipped('libsvm');
    end

function test_classify_libsvm_no_autoscale
% Regression test for cosmo_classify_libsvm with autoscale disabled. If
% the 'libsvm' external is not available, the classifier must throw an
% exception and the test is reported as skipped.

    % silence warnings; the previous state is restored on exit
    warning_state=cosmo_warning();
    cleaner=onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');

    cfy=@cosmo_classify_libsvm;
    opt=struct();
    opt.autoscale=false;
    handle=get_predictor(cfy,opt);
    if ~cosmo_check_external('libsvm',false)
        assertExceptionThrown(handle,'');
        % use the local helper, as all other tests in this file do, so
        % that a missing external is reported in the same way everywhere
        notify_test_skipped('libsvm');
        return;
    end

    assert_predictions_equal(handle,[8 3 3 8 6 6 1 3 7 5 7 6 4 ...
                                    1 2 7 7 7 1 2 8 1 9 6 7 1 3]');
    general_test_classifier(cfy,opt);



function test_classify_libsvm_t0
% Tests cosmo_classify_libsvm with the default (linear kernel) type t=0.
% The kernel type and several other options are passed in different
% ways; every variant must give the same stored predictions.
    classifier=@cosmo_classify_libsvm;

    expected_pred=[8 3 3 8 6 6 1 3 7 5 7 6 4 ...
                   1 2 7 7 7 1 2 8 1 9 6 7 1 3]';

    % option sets, with or without some of the default options spelled out
    param_sets={{},...
                {'t','0'},...
                {'t',0},...
                {'t',0,'autoscale',false,'s','0','r',0,'c',1,'h',1}};

    for k=1:numel(param_sets)
        param=param_sets{k};

        % no options: the predictor gets no options argument, and the
        % generic checks get an empty struct
        opt=[];
        opt_struct=struct();
        if ~isempty(param)
            opt=cosmo_structjoin(param);
            opt_struct=opt;
        end

        predictor=get_predictor(classifier,opt);

        if ~cosmo_check_external('libsvm',false)
            assertExceptionThrown(predictor,'');
            notify_test_skipped('libsvm');
            return;
        end

        assert_predictions_equal(predictor,expected_pred);
        general_test_classifier(classifier,opt_struct);
    end

function test_classify_libsvm_t2
% Regression test for cosmo_classify_libsvm with option t set to 2. If the
% 'libsvm' external is not available, the classifier must throw an
% exception and the test is reported as skipped.
    classifier=@cosmo_classify_libsvm;
    opt=struct('t',2);
    predictor=get_predictor(classifier,opt);

    if cosmo_check_external('libsvm',false)
        % libsvm uses autoscale by default
        expected_pred=[1 3 6 8 6 6 8 3 7 5 7 5 4 ...
                       5 7 7 7 8 1 2 8 1 5 5 7 1 7]';
        assert_predictions_equal(predictor,expected_pred);
        general_test_classifier(classifier,opt);
    else
        assertExceptionThrown(predictor,'');
        notify_test_skipped('libsvm');
    end

function test_classify_svm
% Tests cosmo_classify_svm, which is expected to work if the 'svm'
% external is available. The expected predictions depend on whether
% libsvm is available; explicitly selecting the other implementation
% through the svm option must raise an exception.

    % Clear the function so that it starts without retained state.
    % Function syntax is used on purpose: with command syntax
    % ("clear cosmo_check_external()") the parentheses would be part of
    % the name passed to clear, so that nothing would be cleared.
    clear('cosmo_check_external');

    cfy=@cosmo_classify_svm;
    handle=get_predictor(cfy);
    if ~cosmo_check_external('svm',false)
        assertExceptionThrown(handle,'');
        notify_test_skipped('svm');
        return;
    end

    % matlab and libsvm show slightly different results
    if cosmo_check_external('libsvm',false)
        pred=[8 3 3 8 6 6 1 3 7 5 7 6 4 1 2 7 7 7 1 2 8 1 9 6 7 1 3]';

        good_opt=struct();
        good_opt.svm='libsvm';
        bad_opt=struct();
        bad_opt.svm='matlabsvm';
    else
        % do not show warning message; the previous warning state is
        % restored when this function exits
        warning_state=cosmo_warning();
        cleaner=onCleanup(@()cosmo_warning(warning_state));
        cosmo_warning('off');

        pred=[1 3 9 7 6 6 9 3 7 5 6 6 4 1 7 7 7 7 1 7 7 1 7 6 7 1 9]';

        good_opt=struct();
        good_opt.svm='matlabsvm';
        bad_opt=struct();
        bad_opt.svm='libsvm';
    end

    assert_predictions_equal(handle,pred);
    general_test_classifier(cfy);
    general_test_classifier(cfy,good_opt);

    % the generic checks must fail with an exception for the other
    % implementation
    assertExceptionThrown(@()general_test_classifier(cfy,bad_opt),'');

function general_test_classifier(cfy_base,opt)
% Runs the checks shared by all classifiers.
%
% Inputs:
%   cfy_base    function handle of the classifier
%   opt         (optional) options passed as fourth argument to cfy_base
%
% Checks accuracy on data without and with signal, and the handling of
% illegal and empty input.

    % three-argument classifier, with the options bound if provided
    cfy=cfy_base;
    if nargin>=2
        cfy=@(x,y,z)cfy_base(x,y,z,opt);
    end

    assert_chance_null_data(cfy);
    assert_above_chance_informative_data(cfy);
    assert_throws_expected_exceptions(cfy_base,cfy);

function assert_chance_null_data(cfy)
% Asserts that the accuracy of classifier cfy is between 0.3 and 0.7
% (inclusive) for data generated without signal.
    no_signal=0;
    min_accuracy=0.3;
    max_accuracy=0.7;
    assert_accuracy_in_range(cfy, no_signal, min_accuracy, max_accuracy);

function assert_above_chance_informative_data(cfy)
% Asserts that the accuracy of classifier cfy is between 0.8 and 1
% (inclusive) for data generated with signal strength 4.
    signal_strength=4;
    min_accuracy=0.8;
    max_accuracy=1;
    assert_accuracy_in_range(cfy, signal_strength, ...
                                min_accuracy, max_accuracy);

function assert_accuracy_in_range(cfy, sigma, min_val, max_val)
% Asserts that classifier cfy reaches an accuracy in [min_val, max_val].
%
% Inputs:
%   cfy         function handle called as cfy(train_samples,
%               train_targets, test_samples)
%   sigma       signal strength passed to generate_informative_data
%   min_val     lowest accepted accuracy (inclusive)
%   max_val     highest accepted accuracy (inclusive)
    [train_samples,train_targets,test_samples,test_targets]=...
                                    generate_informative_data(sigma);

    predicted=cfy(train_samples, train_targets, test_samples);

    % proportion of test samples that were predicted correctly
    accuracy=mean(predicted==test_targets);

    assertTrue(accuracy>=min_val);
    assertTrue(accuracy<=max_val);



function [tr_s,tr_t, te_s, te_t]=generate_informative_data(sigma)
% Generates random two-class training and test data with class signal.
%
% Input:
%   sigma       scaling of the class-specific offsets; with sigma=0 the
%               samples are Gaussian noise only
%
% Outputs:
%   tr_s        400x10 training samples
%   tr_t        400x1 training targets, alternating 1 and 2
%   te_s        400x10 test samples
%   te_t        400x1 test targets, identical to tr_t
%
% Each feature has one random offset, which is added to the samples of
% one class only; the same offsets are used for training and test data.
    n_classes=2;
    n_per_class=200;
    n_features=10;
    n_samples=n_classes*n_per_class;

    % one offset per feature (drawn first, before the sample noise)
    offsets=randn(1,n_features)*sigma;

    class_labels=repmat((1:n_classes)',n_per_class,1);
    tr_t=class_labels;
    te_t=class_labels;

    tr_s=randn(n_samples,n_features);
    te_s=randn(n_samples,n_features);

    % features are assigned to classes in alternation (1, 2, 1, 2, ...)
    class_of_feature=mod((1:n_features)-1,n_classes)+1;

    for col=1:n_features
        rows=class_labels==class_of_feature(col);
        offset=offsets(col);

        tr_s(rows,col)=tr_s(rows,col)+offset;
        te_s(rows,col)=te_s(rows,col)+offset;
    end



function assert_throws_expected_exceptions(cfy_base,cfy)
% Checks the input handling of a classifier.
%
% Inputs:
%   cfy_base    function handle of the classifier itself
%   cfy         three-argument function handle that calls the classifier
%
% Illegal input must raise an exception; input with a single class or
% without features is checked by assert_deals_with_empty_input.
    assert_throws_illegal_input_exceptions(cfy);

    assert_deals_with_empty_input(cfy_base,cfy);

function assert_throws_illegal_input_exceptions(cfy)
% Asserts that classifier cfy throws an exception with an empty
% identifier for each of a set of illegal inputs.
%
% Input:
%   cfy     function handle called as cfy(train_samples, train_targets,
%           test_samples)

    % silence warnings; the previous state is restored on exit
    orig_warning_state=cosmo_warning();
    warning_resetter=onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    % each element has the three arguments for a single call
    illegal_inputs={{[1 2],[1;2],[1 2]},...
                    {[1;2],[1 2],[1 2]},...
                    {[1 2],[1 2],[1 2]},...
                    {[1 2],1,[1;2]},...
                    {[1;2],1,[1 2]},...
                    {[1 2; 3 4; 5 6],[1;1],[1 2]},...
                    {[1 2; 3 4; 5 6],[1;1;1],[1 2 3]}};

    for k=1:numel(illegal_inputs)
        args=illegal_inputs{k};
        assertExceptionThrown(@()cfy(args{:}),'');
    end


function assert_deals_with_empty_input(cfy_base,cfy)
% Checks that a classifier works with degenerate but legal input.
%
% Inputs:
%   cfy_base    function handle of the classifier itself; used to decide
%               whether the single-class calls are made
%   cfy         three-argument function handle that calls the classifier

    % classifiers for which the single-class calls are not made
    non_one_class_classifiers={@cosmo_classify_matlabsvm_2class,...
                               @cosmo_classify_meta_feature_selection,...
                               @cosmo_meta_feature_selection_classifier};

    can_handle_single_class=true;
    for k=1:numel(non_one_class_classifiers)
        if isequal(cfy_base,non_one_class_classifiers{k})
            can_handle_single_class=false;
            break;
        end
    end

    if can_handle_single_class
        % training data with a single class; these calls should pass
        cfy([1 2; 3 4],[1;1],[1 2]);
        cfy([1 2; 3 4; 5 6],[1;1;1],[1 2]);
    end

    % no features, should still make prediction
    train_samples=zeros(4,0);
    train_targets=[1 1 2 2]';
    test_samples=zeros(2,0);

    first_pred=cfy(train_samples,train_targets,test_samples);
    assertEqual(size(first_pred),[2 1]);
    assertTrue(all(first_pred==1 | first_pred==2));

    % repeating the call must give the same predictions
    second_pred=cfy(train_samples,train_targets,test_samples);
    assertEqual(first_pred,second_pred);




function handle=get_predictor(cfy,opt,nclasses)
% Returns a function handle that applies a classifier to fixed data.
%
% Inputs:
%   cfy         function handle of a classifier
%   opt         (optional) options passed as fourth argument to cfy;
%               if absent or empty, cfy is called with three arguments
%   nclasses    (optional) number of classes in the data (default: 9)
%
% Output:
%   handle      function handle without arguments; calling it returns the
%               output of cfy for the data from generate_data(nclasses)
    if nargin<3
        nclasses=9;
    end

    has_opt=nargin>=2 && ~isempty(opt);
    if has_opt
        extra_args={opt};
    else
        extra_args={};
    end

    [train_samples,train_targets,test_samples]=generate_data(nclasses);
    handle=@()cfy(train_samples,train_targets,test_samples,...
                                                    extra_args{:});


function assert_predictions_equal(handle, targets)
% Asserts that handle() returns the expected predictions, twice.
%
% Inputs:
%   handle      function handle without arguments that returns predictions
%   targets     expected predictions
    first_pred=handle();
    assertEqual(first_pred, targets);

    % test caching, if implemented: a second call must give the same
    % predictions as the first one
    second_pred=handle();
    assertEqual(first_pred, second_pred);


function [tr_samples,tr_targets,te_samples]=generate_data(nclasses)
% Returns training and test data from a synthetic dataset.
%
% Input:
%   nclasses    value for the 'ntargets' option of cosmo_synthetic_dataset
%
% Outputs:
%   tr_samples  samples with chunk value above 3
%   tr_targets  targets corresponding to tr_samples
%   te_samples  samples with chunk value of at most 3
    ds=cosmo_synthetic_dataset('ntargets',nclasses,'nchunks',6);

    in_test_set=ds.sa.chunks<=3;
    in_train_set=~in_test_set;

    te_samples=ds.samples(in_test_set,:);
    tr_samples=ds.samples(in_train_set,:);
    tr_targets=ds.sa.targets(in_train_set);

function notify_test_skipped(external)
% Reports that a test is skipped because an external is not available.
%
% Input:
%   external    name of the external
%
% The output of cosmo_skip_test_if_no_external is asserted to be true;
% presumably that is the case when the external is absent and the test
% was registered as skipped (TODO confirm against that function).
    was_skipped=cosmo_skip_test_if_no_external(external);
    assertTrue(was_skipped);