function predicted=cosmo_classify_libsvm(samples_train, targets_train, samples_test, opt)
% libsvm-based SVM classifier
%
% predicted=cosmo_classify_libsvm(samples_train, targets_train, samples_test, opt)
%
% Inputs
% samples_train PxR training data for P samples and R features
% targets_train Px1 training data classes
% samples_test QxR test data
% opt (optional) struct with options for svmtrain
% .autoscale If true (default), z-scoring is done on the training
% set; the test set is z-scored using the mean and std
% estimates from the training set.
% ? any option supported by libsvm's svmtrain.
%
% Output
% predicted Qx1 predicted data classes for samples_test
%
% Notes:
% - this function requires libsvm version 3.18 or later:
% https://github.com/cjlin1/libsvm
% - by default a linear kernel is used ('-t 0')
% - this function uses LIBSVM's svmtrain function, which has the same
% name as matlab's builtin version. Use of this function is not
% supported when matlab's svmtrain precedes in the matlab path; in
% that case, adjust the path or use cosmo_classify_matlabsvm instead.
% - for a guide on svm classification, see
% http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
% - By default this function performs z-scoring of the data. To switch
% this off, set 'autoscale' to false
% - cosmo_crossvalidate and cosmo_crossvalidation_measure
% provide an option 'normalization' to perform data scaling
%
%
% See also svmtrain, svmclassify, cosmo_classify_svm,
% cosmo_classify_matlabsvm
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
if nargin<4
opt=struct();
end
% support repeated testing on different data after training every time
% on the same data. This is achieved by caching the training data
% and associated model
persistent cached_targets_train;
persistent cached_samples_train;
persistent cached_opt;
persistent cached_model;
cache_limit=1e5; % avoid caching huge training sets
if isequal(cached_targets_train, targets_train) && ...
isequal(cached_opt, opt) && ...
numel(samples_train) < cache_limit && ...
isequal(cached_samples_train, samples_train)
% use cache
model=cached_model;
else
model=train(samples_train,targets_train,opt);
% store model
cached_targets_train=targets_train;
cached_samples_train=samples_train;
cached_opt=opt;
cached_model=model;
end
predicted=test(model, samples_test);
function model=train(samples_train,targets_train,opt)
[ntrain, nfeatures]=size(samples_train);
model=struct();
model.nfeatures=nfeatures;
model.normalize=[]; % off by default
% check input size
ntrain_=numel(targets_train);
if ntrain_~=ntrain
error('illegal input size');
end
% construct options string for svmtrain
opt_str=libsvm_opt2str(opt);
% auto-scale is the default
autoscale= ~isstruct(opt) || ...
~isfield(opt,'autoscale') || ...
isempty(opt.autoscale) || ...
opt.autoscale;
% perform autoscale if necessary
if autoscale
[samples_train, params]=cosmo_normalize(samples_train, ...
'zscore', 1);
model.normalize=params;
end
train_func=@()svmtrain(targets_train(:), ...
samples_train, ...
opt_str);
model.libsvm_model=eval_with_check_external(train_func);
function output=eval_with_check_external(func)
% Evaluates func() in try-catch block. If it fails, this may be due
% to missing libsvm and/or libsvm conflicting with Matlab's svm.
% Therefore first cosmo_check_external is called, which will give an
% informative error message if that is the case. Otherwise the original
% message is shown.
try
output=func();
catch
cosmo_check_external('libsvm');
rethrow(lasterror());
end
function predicted=test(model, samples_test)
[ntest, nfeatures]=size(samples_test);
if nfeatures~=model.nfeatures
error(['Number of features in train set (%d) and '...
'test set (%d) do not match'],...
model.nfeatures,nfeatures);
end
if ~isempty(model.normalize)
samples_test=cosmo_normalize(samples_test, model.normalize);
end
test_opt_str='-q'; % quiet (no output)
test_func=@()svmpredict(NaN(ntest,1), samples_test, ...
model.libsvm_model, test_opt_str);
predicted=eval_with_check_external(test_func);
function opt_str=libsvm_opt2str(opt)
persistent cached_opt
persistent cached_opt_str
if ~isequal(opt, cached_opt)
default_opt={'t','q';...
'0',''};
n_default=size(default_opt,2);
libsvm_opt_keys={'s','t','d','g','r','c','n','p',...
'm','e','h','n','wi','v'};
opt_struct=cosmo_structjoin(opt);
keys=intersect(fieldnames(opt_struct),libsvm_opt_keys);
n_keys=numel(keys);
use_defaults=true(1,n_default);
libsvm_opt=cell(2,n_keys);
for k=1:n_keys
key=keys{k};
default_pos=find(cosmo_match(default_opt(1,:),(key)),1);
if ~isempty(default_pos)
use_defaults(default_pos)=false;
end
value=opt_struct.(key);
if isnumeric(value)
value=sprintf('%d',value);
end
libsvm_opt(:,k)={key;value};
end
all_opt=[libsvm_opt default_opt(:, use_defaults)];
cached_opt_str=sprintf('-%s %s ', all_opt{:});
cached_opt=opt;
end
opt_str=cached_opt_str;