cosmo classify naive bayes skl

function predicted=cosmo_classify_naive_bayes(samples_train, targets_train, samples_test, unused)
% naive bayes classifier
%
% predicted=cosmo_classify_naive_bayes(samples_train, targets_train, samples_test[, opt])
%
% Inputs:
%   samples_train      PxR training data for P samples and R features
%   targets_train      Px1 training data classes
%   samples_test       QxR test data
%   opt                (currently ignored)
%
% Output:
%   predicted          Qx1 predicted data classes for samples_test
%
% Example:
%     ds=cosmo_synthetic_dataset('ntargets',5,'nchunks',15);
%     test_chunk=1;
%     te=cosmo_slice(ds,ds.sa.chunks==test_chunk);
%     tr=cosmo_slice(ds,ds.sa.chunks~=test_chunk);
%     unused=struct();
%     pred=cosmo_classify_naive_bayes(tr.samples,tr.sa.targets,...
%                                        te.samples,unused);
%     disp([te.sa.targets pred])
%     %||      1     1
%     %||      2     2
%     %||      3     3
%     %||      4     4
%     %||      5     5
%
% See also: cosmo_crossvalidate, cosmo_crossvalidation_measure
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    persistent cached_targets_train;
    persistent cached_samples_train;
    persistent cached_model;

    if isequal(cached_targets_train, targets_train) && ...
            isequal(cached_samples_train, samples_train)
        model=cached_model;
    else
        model=train(samples_train, targets_train);
        cached_targets_train=targets_train;
        cached_samples_train=samples_train;
        cached_model=model;
    end

    predicted=test(model, samples_test);

function predicted=test(model, samples_test)
    mus=model.mus;
    vars=model.vars;
    log_class_probs=model.log_class_probs;
    classes=model.classes;

    [ntest,nfeatures]=size(samples_test);
    if nfeatures~=size(mus,2)
        error('size mismatch');
    end

    predicted=zeros(ntest,1);

    for k=1:ntest
        sample=samples_test(k,:);
        log_ps=log_normal_pdf(sample, mus, vars);

        % make octave more compatible with matlab: convert nan to 1
        log_ps(isnan(log_ps))=1;

        % being 'naive' we assume independence - so take the product of the
        % p values. (for better precision we take the log of the
        % probablities and sum them)
        log_test_prob=sum(log_ps,2)+log_class_probs;

        % find the one with the highest probability
        [unused, mx_idx]=max(log_test_prob);

        predicted(k)=classes(mx_idx);
    end

function ps=log_normal_pdf(xs, mus, vars)
    ps=-.5*(log(2*pi*vars) + bsxfun(@minus,xs,mus).^2./vars);


function model=train(samples_train, targets_train)
    [ntrain,nfeatures]=size(samples_train);
    if ntrain~=numel(targets_train)
        error('size mismatch');
    end

    [class_idxs,classes_cell]=cosmo_index_unique({targets_train});
    classes=classes_cell{1};
    nclasses=numel(classes);

    % allocate space for statistics of each class
    mus=zeros(nclasses,nfeatures);
    vars=zeros(nclasses,nfeatures);
    log_class_probs=zeros(nclasses,1);

    % compute means and standard deviations of each class
    for k=1:nclasses
        idx=class_idxs{k};
        nsamples_in_class=numel(idx); % number of samples
        if nsamples_in_class<2
            error(['Cannot train: class %d has only %d samples, %d '...
                    'are required'],nsamples_in_class,classes(k));
        end

        d=samples_train(idx,:); % samples in this class
        mu=mean(d); %mean
        mus(k,:)=mu;

        % variance - faster than 'var'
        vars(k,:)=sum(bsxfun(@minus,mu,d).^2,1)/nsamples_in_class;

        % log of class probability
        log_class_probs(k)=log(nsamples_in_class/ntrain);
    end

    model.mus=mus;
    model.vars=vars;
    model.log_class_probs=log_class_probs;
    model.classes=classes;