cosmo classify libsvm skl

function predicted = cosmo_classify_libsvm(samples_train, targets_train, samples_test, opt)
    % libsvm-based SVM classifier
    %
    % predicted=cosmo_classify_libsvm(samples_train, targets_train, samples_test, opt)
    %
    % Inputs
    %   samples_train      PxR training data for P samples and R features
    %   targets_train      Px1 training data classes
    %   samples_test       QxR test data
    %   opt                (optional) struct with options for svmtrain
    %     .autoscale       If true (default), z-scoring is done on the training
    %                      set; the test set is z-scored using the mean and std
    %                      estimates from the training set.
    %     ?                any option supported by libsvm's svmtrain.
    %
    % Output
    %   predicted          Qx1 predicted data classes for samples_test
    %
    % Notes:
    %  - this function requires libsvm version 3.18 or later:
    %    https://github.com/cjlin1/libsvm
    %  - by default a linear kernel is used ('-t 0')
    %  - this function uses LIBSVM's svmtrain function, which has the same
    %    name as matlab's builtin version. Use of this function is not
    %    supported when matlab's svmtrain precedes in the matlab path; in
    %    that case, adjust the path or use cosmo_classify_matlabsvm instead.
    %  - for a guide on svm classification, see
    %      http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
    %  - By default this function performs z-scoring of the data. To switch
    %    this off, set 'autoscale' to false
    %  - cosmo_crossvalidate and cosmo_crossvalidation_measure
    %    provide an option 'normalization' to perform data scaling
    %  - the most recently trained model is kept in persistent variables
    %    (together with the training data, targets and options it was
    %    trained with), so that repeated calls with identical training
    %    input only train once. Training sets with 1e5 or more elements
    %    are neither looked up in, nor stored in, this cache.
    %
    %
    % See also svmtrain, svmclassify, cosmo_classify_svm,
    %          cosmo_classify_matlabsvm
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    if nargin < 4
        opt = struct();
    end

    % support repeated testing on different data after training every time
    % on the same data. This is achieved by caching the training data
    % and associated model
    persistent cached_targets_train
    persistent cached_samples_train
    persistent cached_opt
    persistent cached_model

    cache_limit = 1e5; % avoid caching huge training sets
    is_cacheable = numel(samples_train) < cache_limit;

    % Persistent variables start out as [], so also require that a model
    % has actually been stored; otherwise all-empty inputs would be
    % mistaken for a cache hit. The (potentially expensive) comparison of
    % the sample matrices is done last.
    cache_hit = is_cacheable && ...
                isstruct(cached_model) && ...
                isequal(cached_targets_train, targets_train) && ...
                isequal(cached_opt, opt) && ...
                isequal(cached_samples_train, samples_train);

    if cache_hit
        % use cache
        model = cached_model;
    else
        model = train(samples_train, targets_train, opt);

        if is_cacheable
            % store model; training sets at or above the limit are never
            % looked up, so keeping a copy of them would only waste
            % memory. In that case the previous cache entry is left
            % untouched.
            cached_targets_train = targets_train;
            cached_samples_train = samples_train;
            cached_opt = opt;
            cached_model = model;
        end
    end

    predicted = test(model, samples_test);

function model = train(samples_train, targets_train, opt)
    % Train a libsvm model on the training samples
    %
    % Inputs
    %   samples_train      training data, one row per sample
    %   targets_train      training classes, one element per sample
    %   opt                options; see libsvm_opt2str for how these are
    %                      converted to an svmtrain options string.
    %                      The field .autoscale controls z-scoring.
    %
    % Output
    %   model              struct with fields
    %     .nfeatures       number of columns in samples_train
    %     .normalize       normalization parameters returned by
    %                      cosmo_normalize, or [] if autoscale is off
    %     .libsvm_model    output of svmtrain

    [n_samples, n_features] = size(samples_train);

    % every training sample needs exactly one target
    if numel(targets_train) ~= n_samples
        error('illegal input size');
    end

    model = struct();
    model.nfeatures = n_features;
    model.normalize = []; % stays empty unless autoscale is applied below

    % options string passed on to svmtrain
    svm_opt_str = libsvm_opt2str(opt);

    % autoscale is on unless opt is a struct with a non-empty
    % 'autoscale' field, in which case that field decides
    has_autoscale_setting = isstruct(opt) && ...
                            isfield(opt, 'autoscale') && ...
                            ~isempty(opt.autoscale);
    do_autoscale = ~has_autoscale_setting || opt.autoscale;

    if do_autoscale
        % z-score the training data and remember the parameters, so that
        % the same transformation can be applied to the test data
        [samples_train, norm_params] = cosmo_normalize(samples_train, ...
                                                       'zscore', 1);
        model.normalize = norm_params;
    end

    % targets are passed to svmtrain as a column vector
    targets_column = targets_train(:);
    svm_trainer = @()svmtrain(targets_column, samples_train, svm_opt_str);

    model.libsvm_model = eval_with_check_external(svm_trainer);

function output = eval_with_check_external(func)
    % Evaluates func() in try-catch block. If it fails, this may be due
    % to missing libsvm and/or libsvm conflicting with Matlab's svm.
    % Therefore first cosmo_check_external is called, which will give an
    % informative error message if that is the case. Otherwise the original
    % message is shown.
    %
    % Input
    %   func               function handle taking no arguments
    %
    % Output
    %   output             first output of func()

    try
        output = func();
    catch
        % Store the error before doing anything else: lasterror() is
        % global state, and would be overwritten if the call below
        % raises and catches an error internally.
        original_error = lasterror();

        cosmo_check_external('libsvm');
        rethrow(original_error);
    end

function predicted = test(model, samples_test)
    % Predict classes for test samples using a model returned by train
    %
    % Inputs
    %   model              struct with fields .nfeatures, .normalize and
    %                      .libsvm_model
    %   samples_test       test data, one row per sample
    %
    % Output
    %   predicted          output of svmpredict for samples_test

    [n_test, n_features_test] = size(samples_test);

    % train and test data must have the same number of features
    n_features_train = model.nfeatures;
    if n_features_train ~= n_features_test
        error(['Number of features in train set (%d) and '...
               'test set (%d) do not match'], ...
              n_features_train, n_features_test);
    end

    % apply the normalization stored in the model, if there is any
    norm_params = model.normalize;
    if ~isempty(norm_params)
        samples_test = cosmo_normalize(samples_test, norm_params);
    end

    % the test labels are unknown, so NaN placeholders are passed
    placeholder_targets = NaN(n_test, 1);
    quiet_opt = '-q'; % quiet (no output)

    predictor = @()svmpredict(placeholder_targets, samples_test, ...
                              model.libsvm_model, quiet_opt);
    predicted = eval_with_check_external(predictor);

function opt_str = libsvm_opt2str(opt)
    % Convert options to an options string for libsvm's svmtrain
    %
    % Input
    %   opt                options (passed through cosmo_structjoin). Only
    %                      fields with a name listed in libsvm_opt_keys
    %                      below are used; all other fields are ignored.
    %                      Numeric and logical values are converted to a
    %                      string; other values are used as-is.
    %
    % Output
    %   opt_str            string of the form '-key value -key value ',
    %                      with the defaults '-t 0' and '-q' added for
    %                      keys not set through opt. ('q' is not among the
    %                      recognised keys, so '-q' is always added.)
    %
    % The result for the most recently used opt is kept in persistent
    % variables, so that repeated calls with the same options are cheap.

    persistent cached_opt
    persistent cached_opt_str

    % Persistent variables are initialised to [], so an input opt equal
    % to [] would match the uninitialised cache and return [] instead of
    % an options string. Therefore also require that a string has
    % actually been stored.
    has_cached_str = ischar(cached_opt_str);

    if ~has_cached_str || ~isequal(opt, cached_opt)
        % first row: keys, second row: values
        default_opt = {'t', 'q'; ...
                       '0', ''};
        n_default = size(default_opt, 2);

        libsvm_opt_keys = {'s', 't', 'd', 'g', 'r', 'c', 'n', 'p', ...
                           'm', 'e', 'h', 'wi', 'v'};
        opt_struct = cosmo_structjoin(opt);

        keys = intersect(fieldnames(opt_struct), libsvm_opt_keys);
        n_keys = numel(keys);

        % a default is dropped as soon as opt provides the same key
        use_defaults = true(1, n_default);

        libsvm_opt = cell(2, n_keys);
        for k = 1:n_keys
            key = keys{k};
            default_pos = find(cosmo_match(default_opt(1, :), key), 1);

            if ~isempty(default_pos)
                use_defaults(default_pos) = false;
            end

            value = opt_struct.(key);
            if isnumeric(value) || islogical(value)
                % logical values must be converted as well; printing
                % them with '%s' below would give char(0) or char(1)
                % rather than '0' or '1'
                value = sprintf('%d', value);
            end

            libsvm_opt(:, k) = {key; value};
        end

        all_opt = [libsvm_opt default_opt(:, use_defaults)];

        % all_opt{:} expands column by column, i.e. key1, value1, key2, ...
        cached_opt_str = sprintf('-%s %s ', all_opt{:});
        cached_opt = opt;
    end

    opt_str = cached_opt_str;