cosmo classify libsvm skl

function predicted=cosmo_classify_libsvm(samples_train, targets_train, samples_test, opt)
% libsvm-based SVM classifier
%
% predicted=cosmo_classify_libsvm(samples_train, targets_train, samples_test, opt)
%
% Inputs
%   samples_train      PxR training data for P samples and R features
%   targets_train      Px1 training data classes
%   samples_test       QxR test data
%   opt                (optional) struct with options for svmtrain
%     .autoscale       If true (default), z-scoring is done on the training
%                      set; the test set is z-scored using the mean and std
%                      estimates from the training set.
%     ...              any option supported by libsvm's svmtrain.
%
% Output
%   predicted          Qx1 predicted data classes for samples_test
%
% Notes:
%  - this function requires libsvm version 3.18 or later:
%    https://github.com/cjlin1/libsvm
%  - by default a linear kernel is used ('-t 0')
%  - this function uses LIBSVM's svmtrain function, which has the same
%    name as matlab's builtin version. Use of this function is not
%    supported when matlab's svmtrain precedes in the matlab path; in
%    that case, adjust the path or use cosmo_classify_matlabsvm instead.
%  - for a guide on svm classification, see
%      http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
%  - By default this function performs z-scoring of the data. To switch
%    this off, set 'autoscale' to false
%  - cosmo_crossvalidate and cosmo_crossvalidation_measure
%    provide an option 'normalization' to perform data scaling
%  - when called repeatedly with identical samples_train, targets_train
%    and opt, the previously trained model is reused instead of being
%    retrained. This only applies to training sets with fewer than 1e5
%    values; larger training sets are never cached.
%
%
% See also svmtrain, svmclassify, cosmo_classify_svm,
%          cosmo_classify_matlabsvm
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    if nargin<4
        opt=struct();
    end

    % support repeated testing on different data after training every time
    % on the same data. This is achieved by caching the training data
    % and associated model
    persistent cached_targets_train;
    persistent cached_samples_train;
    persistent cached_opt;
    persistent cached_model;

    cache_limit=1e5; % avoid caching huge training sets
    is_cacheable=numel(samples_train) < cache_limit;

    % the ~isempty(cached_model) guard prevents reusing the initial
    % (empty) persistent state when the inputs happen to equal it
    if is_cacheable && ...
            ~isempty(cached_model) && ...
            isequal(cached_targets_train, targets_train) && ...
            isequal(cached_opt, opt) && ...
            isequal(cached_samples_train, samples_train)
        % use cache
        model=cached_model;
    else
        model=train(samples_train,targets_train,opt);

        if is_cacheable
            % store model
            cached_targets_train=targets_train;
            cached_samples_train=samples_train;
            cached_opt=opt;
            cached_model=model;
        else
            % do not keep a huge training set in persistent memory; any
            % earlier cache is invalidated just as storing would have done
            cached_targets_train=[];
            cached_samples_train=[];
            cached_opt=[];
            cached_model=[];
        end
    end

    predicted=test(model, samples_test);

function model=train(samples_train,targets_train,opt)
% Train a libsvm model.
%
% Inputs
%   samples_train      PxR training data for P samples and R features
%   targets_train      training classes with P elements
%   opt                options; a struct, or a cell with key-value pairs
%                      (converted with cosmo_structjoin, as is also done
%                      by libsvm_opt2str)
%
% Output
%   model              struct with fields:
%     .nfeatures       number of features R in the training data
%     .normalize       z-scoring parameters from cosmo_normalize, or []
%                      when autoscaling is disabled
%     .libsvm_model    model returned by libsvm's svmtrain
%
% Raises an error if the number of targets does not match the number of
% training samples.

    [ntrain, nfeatures]=size(samples_train);

    % check input size
    if numel(targets_train)~=ntrain
        error('illegal input size');
    end

    model=struct();
    model.nfeatures=nfeatures;
    model.normalize=[]; % off by default

    % construct options string for svmtrain
    opt_str=libsvm_opt2str(opt);

    % options given as key-value pairs in a cell are converted so that an
    % 'autoscale' setting in them is honoured as well
    if iscell(opt)
        opt_struct=cosmo_structjoin(opt);
    else
        opt_struct=opt;
    end

    % auto-scale is the default
    autoscale= ~isstruct(opt_struct) || ...
                ~isfield(opt_struct,'autoscale') || ...
                isempty(opt_struct.autoscale) || ...
                opt_struct.autoscale;

    % perform autoscale if necessary
    if autoscale
        [samples_train, params]=cosmo_normalize(samples_train, ...
                                                    'zscore', 1);
        model.normalize=params;
    end

    train_func=@()svmtrain(targets_train(:), ...
                            samples_train, ...
                            opt_str);

    model.libsvm_model=eval_with_check_external(train_func);


function output=eval_with_check_external(func)
% Evaluates func() and returns its output.
%
% If evaluation fails, this may be due to missing libsvm and/or libsvm
% conflicting with Matlab's svm. Therefore cosmo_check_external('libsvm')
% is called first, which will give an informative error message if that
% is the case. Otherwise the original error is rethrown.
%
% The original error is captured *before* calling cosmo_check_external:
% should any error be raised and caught inside that function, it would
% otherwise replace the value returned by lasterror() and the wrong error
% would be rethrown. lasterror() (rather than 'catch err') is kept for
% compatibility with older Octave versions.

    try
        output=func();
    catch
        caught_error=lasterror();
        cosmo_check_external('libsvm');
        rethrow(caught_error);
    end


function predicted=test(model, samples_test)
% Predict class labels for samples_test with a model produced by train.
%
% Inputs
%   model              struct with fields .nfeatures, .normalize and
%                      .libsvm_model (as returned by train)
%   samples_test       test data; its number of columns must equal
%                      model.nfeatures
%
% Output
%   predicted          labels as returned by libsvm's svmpredict
%
% When the model holds normalization parameters, samples_test is
% normalized with them (i.e. with training-set estimates) first.

    [n_test, n_features_test]=size(samples_test);
    n_features_train=model.nfeatures;

    if n_features_test~=n_features_train
        error(['Number of features in train set (%d) and '...
               'test set (%d) do not match'],...
               n_features_train, n_features_test);
    end

    has_normalization=~isempty(model.normalize);
    if has_normalization
        samples_test=cosmo_normalize(samples_test, model.normalize);
    end

    % svmpredict needs a label vector; labels of the test data are
    % unknown here, so NaN placeholders are passed. '-q' keeps it quiet.
    placeholder_labels=NaN(n_test,1);
    predict_func=@()svmpredict(placeholder_labels, samples_test, ...
                               model.libsvm_model, '-q');

    predicted=eval_with_check_external(predict_func);




function opt_str=libsvm_opt2str(opt)
% Convert options into an option string for libsvm's svmtrain.
%
% Input
%   opt                options, converted with cosmo_structjoin (a struct,
%                      or presumably key-value pairs accepted by it)
%
% Output
%   opt_str            string such as '-c 0.5 -t 0 -q ' containing every
%                      field of opt that is a libsvm svmtrain option;
%                      other fields (e.g. 'autoscale') are ignored. Unless
%                      the user specifies them, '-t 0' (linear kernel) and
%                      '-q' (quiet) are added.
%
% The string for the most recent opt is cached so that repeated calls
% with the same options do not rebuild it.

    persistent cached_opt
    persistent cached_opt_str

    % cached_opt_str is empty only before the first conversion (the
    % result always contains at least '-q'); checking it prevents
    % returning an empty string when opt equals the initial cached_opt
    if isempty(cached_opt_str) || ~isequal(opt, cached_opt)
        % first row: option keys; second row: default values
        default_opt={'t','q';...
                     '0',''};
        n_default=size(default_opt,2);

        % svmtrain options that are passed through
        libsvm_opt_keys={'s','t','d','g','r','c','n','p',...
                         'm','e','h','b','wi','v'};
        opt_struct=cosmo_structjoin(opt);

        keys=intersect(fieldnames(opt_struct),libsvm_opt_keys);
        n_keys=numel(keys);

        use_defaults=true(1,n_default);

        libsvm_opt=cell(2,n_keys);
        for k=1:n_keys
            key=keys{k};

            % a user-supplied value overrides the default, if any
            use_defaults(strcmp(default_opt(1,:),key))=false;

            value=opt_struct.(key);
            if isnumeric(value) || islogical(value)
                % '%d' falls back to '%e' with only six digits for
                % non-integer values (e.g. cost or gamma); '%.15g' keeps
                % full precision and prints integers without decimals
                value=sprintf('%.15g',value);
            end

            libsvm_opt(:,k)={key;value};
        end

        all_opt=[libsvm_opt default_opt(:, use_defaults)];

        cached_opt_str=sprintf('-%s %s ', all_opt{:});
        cached_opt=opt;
    end

    opt_str=cached_opt_str;