function test_suite = test_classify
% tests for cosmo_classify_* functions
%
% Entry point of this test file. The sub-functions defined below are the
% individual test cases and their shared helper functions.
%
% Returns:
%   test_suite    not assigned explicitly in this function; presumably
%                 set by the initTestSuite script provided by the unit
%                 test framework - TODO confirm against the framework
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions = localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
% initTestSuite is not defined in this file; it is invoked as a script so
% that it can access the variables in this function's workspace
% (NOTE(review): assumed from the comments above - confirm)
initTestSuite;
function test_general_classifiers_strong_signal()
    % Runs every classifier on random two-class data, once without any
    % class-specific signal and once with a strong signal.
    %
    % For each classifier this asserts that:
    %   - an exception is raised when the external it depends on is not
    %     available (the remaining checks are then skipped);
    %   - one prediction is returned for each test sample;
    %   - accuracy lies strictly between 0.4 and 0.6 for data without
    %     signal, and exceeds 0.9 for data with a strong signal.
    %
    % note: part of these checks are also implemented in
    %       general_test_classifier (included below)
    if cosmo_wtf('is_matlab')
        csvm_external = '!fitcsvm';
    else
        % on other platforms use an external name that, going by its
        % variable name, is never present
        non_existing_external = '!';
        csvm_external = non_existing_external;
    end

    % each element is {classifier handle, required external}, with an
    % empty string if no external is required
    cfy_cell = {{@cosmo_classify_lda, ''}, ...
                {@cosmo_classify_nn, ''}, ...
                {@cosmo_classify_naive_bayes, ''}, ...
                {@cosmo_classify_libsvm, 'libsvm'}, ...
                {@cosmo_classify_matlabsvm, 'matlabsvm'}, ...
                {@cosmo_classify_matlabcsvm, csvm_external}};
    n_cfy = numel(cfy_cell);

    no_information = 0;
    strong_information = 2 + rand();

    for sigma = [no_information, strong_information]
        [tr_s, tr_t, te_s, te_t] = generate_informative_data(sigma);

        % expected number of predictions follows from the generated data
        % instead of duplicating the sample count used by the generator
        n_test = numel(te_t);

        for k = 1:n_cfy
            cfy = cfy_cell{k}{1};
            external = cfy_cell{k}{2};
            predictor_func = @()cfy(tr_s, tr_t, te_s);

            needs_external = ~strcmp(external, '');
            if needs_external && ~cosmo_check_external(external, false)
                % any exception is accepted when the external is missing
                assertExceptionThrown(predictor_func, '*');
                continue
            end

            pred = predictor_func();

            % verify the output size before comparing with the targets,
            % so that a size problem is reported as such
            assertEqual(numel(pred), n_test);
            acc = mean(pred == te_t);

            if sigma == no_information
                assertTrue(0.4 < acc && acc < 0.6);
            elseif sigma == strong_information
                assertTrue(acc > 0.9);
            else
                % sigma only takes the two values above
                assertFalse(true);
            end
        end
    end
function test_classify_lda
    % Tests cosmo_classify_lda: a regression check of the predictions on
    % the synthetic dataset, the generic classifier checks, and two
    % inputs for which an exception is expected.
    cfy = @cosmo_classify_lda;
    handle = get_predictor(cfy);
    assert_predictions_equal(handle, [1 3 9 8 5 6 8 9 7 5 7 5 4 ...
                                      9 2 7 7 7 1 2 1 1 7 6 7 1 7]');
    general_test_classifier(cfy);

    % helper: assert that calling the classifier with the supplied
    % arguments raises an exception; '' is passed as the expected
    % exception identifier
    aet = @(varargin)assertExceptionThrown(@() ...
                                           cosmo_classify_lda(varargin{:}), '');

    % size mismatch
    x = randn(5, 3);
    y = randn(2, 3);
    aet(x, [1 1 1 2 2], y);

    % too many features; the test samples have the same number of
    % features as the training samples, so that the train/test feature
    % counts agree and the large feature count is what is exercised
    x = zeros(1, 1e4);
    aet(x, 1, x);
function test_classify_naive_bayes
    % Regression check of cosmo_classify_naive_bayes predictions on the
    % synthetic dataset, followed by the generic classifier checks.
    classifier = @cosmo_classify_naive_bayes;
    expected_pred = [1 7 3 9 2 2 8 9 7 4 7 2 4 ...
                     8 2 7 7 7 1 2 7 1 7 2 7 1 9]';

    predictor = get_predictor(classifier);
    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier);
function test_classify_meta_feature_selection
    % Regression check of cosmo_classify_meta_feature_selection, using
    % cosmo_classify_nn as child classifier and
    % cosmo_anova_feature_selector to keep a ratio of .6 of the features,
    % followed by the generic classifier checks.
    %
    % % Oct 2016: got strange glibc error when using travis
    %
    % ** glibc detected *** /usr/bin/octave-cli: double free or corruption (!prev): 0x0000000003570030 ***
    % ======= Backtrace: =========
    % /lib/x86_64-linux-gnu/libc.so.6(+0x7da26)[0x2abf370a4a26]
    % /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(_ZN12symbol_table7cleanupEv+0x7d)[0x2abf3678001d]
    % /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(+0xb486de)[0x2abf367b26de]
    % /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(_Z17clean_up_and_exitib+0x17)[0x2abf367b4057]
    % /usr/lib/x86_64-linux-gnu/liboctinterp.so.3(octave_execute_interpreter+0xd1a)[0x2abf35f341ca]
    % /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xed)[0x2abf370487ed]
    %
    % https://travis-ci.org/nno/CoSMoMVPA/builds/166411636
    %
    % The fix seemed to be using cosmo_classify_nn instead of
    % cosmo_classify_lda as child_classifier - but that may just be
    % addressing a symptom, not the root cause.
    classifier = @cosmo_classify_meta_feature_selection;
    opt = struct('child_classifier', @cosmo_classify_nn, ...
                 'feature_selector', @cosmo_anova_feature_selector, ...
                 'feature_selection_ratio_to_keep', .6);
    expected_pred = [1 1 7 1 6 6 1 1 7 5 1 1 4 ...
                     9 6 7 7 1 1 2 1 3 9 6 7 8 7]';

    predictor = get_predictor(classifier, opt);
    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier, opt);
function test_cosmo_meta_feature_selection_classifier
    % Regression check of cosmo_meta_feature_selection_classifier with
    % the same options and expected predictions as used for
    % cosmo_classify_meta_feature_selection.
    %
    % The function under test is deprecated and shows a warning, so
    % warnings are switched off for the duration of this test; the
    % previous warning state is restored when this function exits.
    orig_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier = @cosmo_meta_feature_selection_classifier;
    opt = struct('child_classifier', @cosmo_classify_nn, ...
                 'feature_selector', @cosmo_anova_feature_selector, ...
                 'feature_selection_ratio_to_keep', .6);
    expected_pred = [1 1 7 1 6 6 1 1 7 5 1 1 4 ...
                     9 6 7 7 1 1 2 1 3 9 6 7 8 7]';

    predictor = get_predictor(classifier, opt);
    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier, opt);
function test_classify_nn
    % Regression check of cosmo_classify_nn predictions on the synthetic
    % dataset, followed by the generic classifier checks.
    classifier = @cosmo_classify_nn;
    expected_pred = [1 3 6 8 6 6 8 7 7 5 7 7 4 ...
                     9 7 7 7 7 1 2 7 1 5 6 7 1 9]';

    predictor = get_predictor(classifier);
    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier);
function test_classify_knn
    % Regression check of cosmo_classify_knn with the option knn=2 on the
    % synthetic dataset, followed by the generic classifier checks.
    classifier = @cosmo_classify_knn;
    opt = struct('knn', 2);
    expected_pred = [7 3 6 1 7 2 8 7 9 4 7 5 7 1 ...
                     7 7 7 8 1 6 1 1 9 5 8 1 9]';

    predictor = get_predictor(classifier, opt);
    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier, opt);
function test_classify_matlabsvm
    % Regression check of cosmo_classify_matlabsvm.
    %
    % If the 'matlabsvm' external is not available, the test asserts that
    % the classifier raises exceptions and then reports itself as
    % skipped. Warnings are switched off while the test runs; the
    % previous warning state is restored when this function exits.
    orig_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier = @cosmo_classify_matlabsvm;
    predictor = get_predictor(classifier);

    has_external = cosmo_check_external('matlabsvm', false);
    if ~has_external
        assert_throws_illegal_input_exceptions(classifier);
        assertExceptionThrown(predictor, '');
        notify_test_skipped('matlabsvm');
        return
    end

    expected_pred = [1 3 9 7 6 6 9 3 7 5 6 6 4 ...
                     1 7 7 7 7 1 7 7 1 7 6 7 1 9]';
    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier);
function test_classify_matlabcsvm
    % Regression check of cosmo_classify_matlabcsvm.
    %
    % Both the 'matlabcsvm' and the '!fitcsvm' externals must be
    % available; otherwise the test asserts that the classifier raises
    % exceptions and then reports itself as skipped. Warnings are
    % switched off while the test runs; the previous warning state is
    % restored when this function exits.
    orig_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier = @cosmo_classify_matlabcsvm;
    predictor = get_predictor(classifier);

    % the second external is only queried if the first one is present
    has_externals = cosmo_check_external('matlabcsvm', false) && ...
                    cosmo_check_external('!fitcsvm', false);
    if ~has_externals
        assert_throws_illegal_input_exceptions(classifier);
        % no exception identifier is specified here
        assertExceptionThrown(predictor);
        notify_test_skipped('matlabcsvm');
        return
    end

    expected_pred = [1 2 3 4 5 6 7 8 9 1 2 3 4 ...
                     5 6 7 8 9 1 2 3 4 5 6 7 8 9]';
    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier);
function test_classify_matlabsvm_2class
    % Tests cosmo_classify_matlabsvm_2class.
    %
    % If the 'matlabsvm' external is not available, the test asserts that
    % the classifier raises exceptions and then reports itself as
    % skipped. Otherwise it checks that the default nine-class data is
    % rejected, that predictions on two-class data match the stored
    % values, the generic classifier checks, and two option settings for
    % which an exception is expected.
    %
    % Warnings are switched off while the test runs; the previous warning
    % state is restored when this function exits.
    orig_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier = @cosmo_classify_matlabsvm_2class;
    predictor = get_predictor(classifier);

    has_external = cosmo_check_external('matlabsvm', false);
    if ~has_external
        assert_throws_illegal_input_exceptions(classifier);
        assertExceptionThrown(predictor, '');
        notify_test_skipped('matlabsvm');
        return
    end

    % cannot deal with nine classes
    assertExceptionThrown(predictor, '');

    n_classes = 2;
    predictor = get_predictor(classifier, struct(), n_classes);
    assert_predictions_equal(predictor, [1 2 2 2 1 2]');

    general_test_classifier(classifier);

    % test non-convergence
    assert_throws = @(exc, varargin)assertExceptionThrown(@() ...
                        cosmo_classify_matlabsvm_2class(varargin{:}), exc);

    train_samples = [0 0; 0 1; 1 0; 1 1];
    train_targets = [1 2 2 1];
    test_samples = NaN(2);

    opt = struct('options', struct('MaxIter', 1));
    assert_throws('', train_samples, train_targets, test_samples, opt);

    % a struct is used as value for the tolkkt option
    opt.tolkkt = struct();
    assert_throws('stats:svmtrain:badTolKKT', ...
                  train_samples, train_targets, test_samples, opt);
function test_classify_libsvm_with_autoscale
    % Regression check of cosmo_classify_libsvm with autoscale=true.
    %
    % If the 'libsvm' external is not available, the test asserts that
    % the classifier raises an exception and then reports itself as
    % skipped. Warnings are switched off while the test runs; the
    % previous warning state is restored when this function exits.
    orig_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    classifier = @cosmo_classify_libsvm;
    opt = struct('autoscale', true);
    predictor = get_predictor(classifier, opt);

    has_external = cosmo_check_external('libsvm', false);
    if ~has_external
        assertExceptionThrown(predictor, '');
        notify_test_skipped('libsvm');
        return
    end

    expected_pred = [8 3 3 8 6 6 1 3 7 5 7 6 4 ...
                     1 2 7 7 7 1 2 8 1 9 6 7 1 3]';
    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier, opt);
function test_classify_libsvm_no_autoscale
    % Regression check of cosmo_classify_libsvm with autoscale=false.
    %
    % If the 'libsvm' external is not available, the test asserts that
    % the classifier raises an exception and then reports itself as
    % skipped. Warnings are switched off while the test runs; the
    % previous warning state is restored when this function exits.
    warning_state = cosmo_warning();
    cleaner = onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');

    cfy = @cosmo_classify_libsvm;
    opt = struct();
    opt.autoscale = false;
    handle = get_predictor(cfy, opt);

    if ~cosmo_check_external('libsvm', false)
        assertExceptionThrown(handle, '');
        % report the skip through the helper defined in this file, as
        % done by all other tests in this file
        notify_test_skipped('libsvm');
        return
    end

    assert_predictions_equal(handle, [8 3 3 8 6 6 1 3 7 5 7 6 4 ...
                                      1 2 7 7 7 1 2 8 1 9 6 7 1 3]');
    general_test_classifier(cfy, opt);
function test_classify_libsvm_t0
    % test with default (linear kernel) type t=0
    %
    % cosmo_classify_libsvm is run without options and with several
    % option sets that spell out some of the default values; every
    % variant must give the same predictions. If the 'libsvm' external is
    % not available, the test asserts that an exception is raised and
    % then reports itself as skipped.
    classifier = @cosmo_classify_libsvm;
    expected_pred = [8 3 3 8 6 6 1 3 7 5 7 6 4 ...
                     1 2 7 7 7 1 2 8 1 9 6 7 1 3]';

    % with or without some of the default options; all should give the
    % same result
    param_sets = {{}, ...
                  {'t', '0'}, ...
                  {'t', 0}, ...
                  {'t', 0, 'autoscale', false, 's', '0', 'r', 0, 'c', 1, 'h', 1}};

    for k = 1:numel(param_sets)
        param = param_sets{k};

        if ~isempty(param)
            opt = cosmo_structjoin(param);
            opt_struct = opt;
        else
            % an empty array makes get_predictor omit the options
            % argument; the generic checks get an empty struct instead
            opt = [];
            opt_struct = struct();
        end

        predictor = get_predictor(classifier, opt);

        has_external = cosmo_check_external('libsvm', false);
        if ~has_external
            assertExceptionThrown(predictor, '');
            notify_test_skipped('libsvm');
            return
        end

        assert_predictions_equal(predictor, expected_pred);
        general_test_classifier(classifier, opt_struct);
    end
function test_classify_libsvm_t2
    % Regression check of cosmo_classify_libsvm with the option t=2.
    %
    % If the 'libsvm' external is not available, the test asserts that
    % the classifier raises an exception and then reports itself as
    % skipped.
    classifier = @cosmo_classify_libsvm;
    opt = struct('t', 2);
    predictor = get_predictor(classifier, opt);

    has_external = cosmo_check_external('libsvm', false);
    if ~has_external
        assertExceptionThrown(predictor, '');
        notify_test_skipped('libsvm');
        return
    end

    % libsvm uses autoscale by default
    expected_pred = [1 3 6 8 6 6 8 3 7 5 7 5 4 ...
                     5 7 7 7 8 1 2 8 1 5 5 7 1 7]';
    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier, opt);
function test_classify_svm
    % Tests cosmo_classify_svm.
    %
    % If the 'svm' external is not available, the test asserts that the
    % classifier raises an exception and then reports itself as skipped.
    % Otherwise the expected predictions, and which value of the 'svm'
    % option is accepted, depend on whether the 'libsvm' external is
    % available.

    % NOTE(review): presumably intended to reset state kept inside
    % cosmo_check_external - confirm that this command-syntax form, with
    % the trailing '()', has that effect
    clear cosmo_check_external();

    classifier = @cosmo_classify_svm;
    predictor = get_predictor(classifier);

    has_svm = cosmo_check_external('svm', false);
    if ~has_svm
        assertExceptionThrown(predictor, '');
        notify_test_skipped('svm');
        return
    end

    % matlab and libsvm show slightly different results
    if cosmo_check_external('libsvm', false)
        expected_pred = [8 3 3 8 6 6 1 3 7 5 7 6 4 1 2 7 7 7 1 2 8 1 9 6 7 1 3]';
        supported_svm = 'libsvm';
        unsupported_svm = 'matlabsvm';
    else
        % do not show warning message
        orig_warning_state = cosmo_warning();
        restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));
        cosmo_warning('off');

        expected_pred = [1 3 9 7 6 6 9 3 7 5 6 6 4 1 7 7 7 7 1 7 7 1 7 6 7 1 9]';
        supported_svm = 'matlabsvm';
        unsupported_svm = 'libsvm';
    end

    good_opt = struct('svm', supported_svm);
    bad_opt = struct('svm', unsupported_svm);

    assert_predictions_equal(predictor, expected_pred);

    general_test_classifier(classifier);
    general_test_classifier(classifier, good_opt);

    % requesting the other svm implementation must raise an exception
    assertExceptionThrown(@()general_test_classifier(classifier, bad_opt), '');
function general_test_classifier(cfy_base, opt)
    % Runs the checks shared by all classifiers.
    %
    % Inputs:
    %   cfy_base    handle to the classifier function
    %   opt         optional; if provided it is passed to cfy_base as
    %               fourth argument
    %
    % The classifier must give accuracy in the chance range on data
    % without signal, high accuracy on informative data, and show the
    % expected behaviour for illegal and empty inputs.
    has_opt = nargin >= 2;
    if has_opt
        % bind the options so that the checks can use three arguments
        cfy = @(x, y, z)cfy_base(x, y, z, opt);
    else
        cfy = cfy_base;
    end

    assert_chance_null_data(cfy);
    assert_above_chance_informative_data(cfy);
    assert_throws_expected_exceptions(cfy_base, cfy);
function assert_chance_null_data(cfy)
    % Asserts that the accuracy of classifier cfy on data generated with
    % sigma=0 lies between 0.3 and 0.7 (both inclusive).
    sigma_no_signal = 0;
    acc_bounds = [0.3, 0.7];
    assert_accuracy_in_range(cfy, sigma_no_signal, ...
                             acc_bounds(1), acc_bounds(2));
function assert_above_chance_informative_data(cfy)
    % Asserts that the accuracy of classifier cfy on data generated with
    % sigma=4 lies between 0.8 and 1 (both inclusive).
    sigma_signal = 4;
    acc_bounds = [0.8, 1];
    assert_accuracy_in_range(cfy, sigma_signal, ...
                             acc_bounds(1), acc_bounds(2));
function assert_accuracy_in_range(cfy, sigma, min_val, max_val)
    % Asserts that classification accuracy lies in [min_val, max_val].
    %
    % Inputs:
    %   cfy         classifier handle accepting (train samples, train
    %               targets, test samples) and returning predictions
    %   sigma       signal strength passed to generate_informative_data
    %   min_val     lowest accepted accuracy (inclusive)
    %   max_val     highest accepted accuracy (inclusive)
    [train_samples, train_targets, ...
        test_samples, test_targets] = generate_informative_data(sigma);

    predicted = cfy(train_samples, train_targets, test_samples);
    is_correct = predicted == test_targets;
    accuracy = mean(is_correct);

    assertTrue(accuracy >= min_val);
    assertTrue(accuracy <= max_val);
function [tr_s, tr_t, te_s, te_t] = generate_informative_data(sigma)
    % Generates random two-class train and test data.
    %
    % Input:
    %   sigma    scaling of the class-specific offsets; with sigma=0 no
    %            offset is added
    %
    % Outputs:
    %   tr_s     400x10 train samples
    %   tr_t     400x1 train targets, alternating 1, 2, 1, 2, ...
    %   te_s     400x10 test samples
    %   te_t     400x1 test targets, identical to tr_t
    %
    % Samples are standard normal. Each feature has one random offset
    % (standard normal multiplied by sigma) that is added, in both the
    % train and the test set, to the samples of a single class only:
    % class 1 for features 1, 3, 5, ... and class 2 for features
    % 2, 4, 6, ...
    n_classes = 2;
    n_per_class = 200;
    n_features = 10;
    n_samples = n_classes * n_per_class;

    % random numbers are drawn in this order: offsets, train, test
    offset_per_feature = randn(1, n_features) * sigma;
    class_labels = repmat((1:n_classes)', n_per_class, 1);
    tr_s = randn(n_samples, n_features);
    te_s = randn(n_samples, n_features);

    tr_t = class_labels;
    te_t = class_labels;

    % class that receives the offset, for each feature
    class_of_feature = mod((1:n_features) - 1, n_classes) + 1;

    % n_samples x n_features mask of the elements that receive an offset
    gets_offset = bsxfun(@eq, class_labels, class_of_feature);
    offsets = repmat(offset_per_feature, n_samples, 1);

    tr_s(gets_offset) = tr_s(gets_offset) + offsets(gets_offset);
    te_s(gets_offset) = te_s(gets_offset) + offsets(gets_offset);
function assert_throws_expected_exceptions(cfy_base, cfy)
% Checks the behaviour of a classifier for illegal and for empty inputs.
%
% Inputs:
%   cfy_base    handle to the classifier function itself; passed on to
%               assert_deals_with_empty_input
%   cfy         handle that is actually called with three arguments
%               (train samples, train targets, test samples)
assert_throws_illegal_input_exceptions(cfy);
assert_deals_with_empty_input(cfy_base, cfy);
function assert_throws_illegal_input_exceptions(cfy)
    % Asserts that classifier cfy raises an exception for a series of
    % illegal input combinations.
    %
    % Warnings are switched off while the checks run; the previous
    % warning state is restored when this function exits.
    orig_warning_state = cosmo_warning();
    restore_warnings = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    % each row: {train samples, train targets, test samples}
    illegal_inputs = {{[1 2], [1; 2], [1 2]}, ...
                      {[1; 2], [1 2], [1 2]}, ...
                      {[1 2], [1 2], [1 2]}, ...
                      {[1 2], 1, [1; 2]}, ...
                      {[1; 2], 1, [1 2]}, ...
                      {[1 2; 3 4; 5 6], [1; 1], [1 2]}, ...
                      {[1 2; 3 4; 5 6], [1; 1; 1], [1 2 3]}};

    for k = 1:numel(illegal_inputs)
        args = illegal_inputs{k};
        assertExceptionThrown(@()cfy(args{:}), '');
    end
function assert_deals_with_empty_input(cfy_base, cfy)
    % Checks classifier behaviour for single-class training data and for
    % data without features.
    %
    % Inputs:
    %   cfy_base    handle to the classifier function itself; used to
    %               decide whether the single-class check applies
    %   cfy         handle that is actually called with three arguments

    % classifiers that are not called with single-class training data
    single_class_exempt = {@cosmo_classify_matlabsvm_2class, ...
                           @cosmo_classify_meta_feature_selection, ...
                           @cosmo_meta_feature_selection_classifier, ...
                           @cosmo_classify_matlabcsvm};

    if cosmo_check_external('octave_pkg_statistics_libsvm', false)
        % with this external present the libsvm-based classifiers are
        % exempt as well
        single_class_exempt{end + 1} = @cosmo_classify_libsvm;
        single_class_exempt{end + 1} = @cosmo_classify_svm;
    end

    is_exempt = false;
    for k = 1:numel(single_class_exempt)
        if isequal(cfy_base, single_class_exempt{k})
            is_exempt = true;
            break
        end
    end

    if ~is_exempt
        % should pass
        cfy([1 2; 3 4], [1; 1], [1 2]);
        cfy([1 2; 3 4; 5 6], [1; 1; 1], [1 2]);
    end

    % no features, should still make prediction
    empty_train = zeros(4, 0);
    empty_test = zeros(2, 0);
    train_targets = [1 1 2 2]';

    first_pred = cfy(empty_train, train_targets, empty_test);
    assertEqual(size(first_pred), [2 1]);
    assertTrue(all(first_pred == 1 | first_pred == 2));

    % a second call with the same input must give the same output
    second_pred = cfy(empty_train, train_targets, empty_test);
    assertEqual(first_pred, second_pred);
function handle = get_predictor(cfy, opt, nclasses)
    % Returns a zero-argument handle that runs a classifier on data from
    % generate_data.
    %
    % Inputs:
    %   cfy         classifier handle
    %   opt         optional; classifier options. If absent or empty, the
    %               classifier is called without an options argument
    %   nclasses    optional; number of classes in the generated data
    %               (default: 9)
    %
    % Output:
    %   handle      function handle; handle() calls cfy with the
    %               generated train samples, train targets and test
    %               samples, followed by opt if it was provided
    extra_args = {};
    if nargin >= 2 && ~isempty(opt)
        extra_args = {opt};
    end

    if nargin < 3
        nclasses = 9;
    end

    [train_samples, train_targets, test_samples] = generate_data(nclasses);
    handle = @()cfy(train_samples, train_targets, test_samples, ...
                    extra_args{:});
function assert_predictions_equal(handle, targets)
    % Asserts that handle() returns the expected predictions, and that a
    % second call returns the same predictions as the first.
    %
    % Inputs:
    %   handle     zero-argument function handle returning predictions
    %   targets    expected predictions
    first_pred = handle();
    assertEqual(first_pred, targets);

    % test caching, if implemented
    second_pred = handle();
    assertEqual(first_pred, second_pred);
function [tr_samples, tr_targets, te_samples] = generate_data(nclasses)
    % Splits a synthetic dataset into a train and a test part.
    %
    % Input:
    %   nclasses      number of targets in the synthetic dataset
    %
    % Outputs:
    %   tr_samples    samples with chunk value above 3
    %   tr_targets    targets of the samples in tr_samples
    %   te_samples    samples with chunk value of at most 3
    n_chunks = 6;
    ds = cosmo_synthetic_dataset('ntargets', nclasses, ...
                                 'nchunks', n_chunks);

    is_test = ds.sa.chunks <= 3;
    is_train = ~is_test;

    te_samples = ds.samples(is_test, :);
    tr_samples = ds.samples(is_train, :);
    tr_targets = ds.sa.targets(is_train);
function notify_test_skipped(external)
    % Calls cosmo_skip_test_if_no_external for the given external and
    % asserts that it returns true.
    %
    % Input:
    %   external    name of the external, e.g. 'libsvm'
    skip_result = cosmo_skip_test_if_no_external(external);
    assertTrue(skip_result);