cosmo classify matlabcsvm skl

function predicted=cosmo_classify_matlabcsvm(samples_train, targets_train, samples_test, opt)
% svm classifier wrapper (around fitcsvm)
%
% predicted=cosmo_classify_matlabcsvm(samples_train, targets_train, samples_test, opt)
%
% Inputs:
%   samples_train      PxR training data for P samples and R features
%   targets_train      Px1 training data classes
%   samples_test       QxR test data
%   opt                struct with options. supports any option that
%                      fitcsvm supports
%
% Output:
%   predicted          Qx1 predicted data classes for samples_test
%
% Notes:
%  - this function uses Matlab's builtin fitcsvm function, which was the
%    successor of svmtrain.
%  - Matlab's SVM classifier is rather slow, especially for multi-class
%    data (more than two classes). When classification takes a long time,
%    consider using libsvm.
%  - for a guide on svm classification, see
%      http://www.csie.ntu.edu.tw/~cjlin/papers/guide/guide.pdf
%    Note that cosmo_crossvalidate and cosmo_crossvalidation_measure
%    provide an option 'normalization' to perform data scaling
%  - As of Matlab 2017a (maybe earlier), Matlab gives the warning that
%      'svmtrain will be removed in a future release. Use fitcsvm instead.'
%    however fitcsvm gives different results than svmtrain; as a result
%    cosmo_classify_matlabcsvm gives different results than
%    cosmo_classify_matlabsvm.
%
% See also fitcsvm, svmclassify, cosmo_classify_matlabsvm.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    if nargin<4, opt=struct(); end

    [ntrain, nfeatures]=size(samples_train);
    [ntest, nfeatures_]=size(samples_test);
    ntrain_=numel(targets_train);

    if nfeatures~=nfeatures_ || ntrain_~=ntrain
        error('illegal input size');
    end

    if ~cached_has_matlabcsvm()
        cosmo_check_external('matlabcsvm');
    end

    [class_idxs,classes]=cosmo_index_unique(targets_train(:));
    nclasses=numel(classes);

    if nfeatures==0 || nclasses==1
        % matlab's svm cannot deal with empty data, so predict all
        % test samples as the class of the first sample
        predicted=targets_train(1) * ones(ntest,1);
        return
    end


    opt_cell=opt2cell(opt);

    % number of pair-wise comparisons
    ncombi=nclasses*(nclasses-1)/2;

    % allocate space for all predictions
    all_predicted=NaN(ntest, ncombi);

    % Consider all pairwise comparisons (over classes)
    % and store the predictions in all_predicted
    pos=0;
    for k=1:(nclasses-1)
        for j=(k+1):nclasses
            pos=pos+1;
            % classify between 2 classes only
            idxs=cat(1,class_idxs{k},class_idxs{j});


            model=fitcsvm(samples_train(idxs,:), targets_train(idxs), ...
                                    opt_cell{:});

            pred=predict(model, samples_test(idxs, :));
            all_predicted(idxs,pos)=pred;
        end
    end

    assert(pos==ncombi);

    % find the classes that were predicted most often.
    % ties are handled by cosmo_winner_indices
    [winners, test_classes]=cosmo_winner_indices(all_predicted);

    predicted=test_classes(winners);



    % helper function to convert cell to struct
function opt_cell=opt2cell(opt)

    if isempty(opt)
        opt_cell=cell(0);
        return;
    end

    fns=fieldnames(opt);

    n=numel(fns);
    opt_cell=cell(1,2*n);
    for k=1:n
        fn=fns{keep_id(k)};
        opt_cell{k*2-1}=fn;
        opt_cell{k*2}=opt.(fn);
    end


function tf=cached_has_matlabcsvm()
    persistent cached_tf;

    if isequal(cached_tf,true)
        tf=true;
        return
    end

    cached_tf=cosmo_check_external('matlabcsvm');
    tf=cached_tf;