cosmo classify naive bayes skl

function predicted = cosmo_classify_naive_bayes(samples_train, targets_train, samples_test, unused)
    % naive bayes classifier
    %
    % predicted=cosmo_classify_naive_bayes(samples_train, targets_train, samples_test[, opt])
    %
    % Inputs:
    %   samples_train      PxR training data for P samples and R features
    %   targets_train      Px1 training data classes
    %   samples_test       QxR test data
    %   opt                (currently ignored)
    %
    % Output:
    %   predicted          Qx1 predicted data classes for samples_test
    %
    % Notes:
    %   - every feature is modelled, within each class, by its own
    %     Gaussian; features are treated as independent ('naive')
    %   - the model trained most recently is kept in persistent variables,
    %     so calling this function repeatedly with identical training data
    %     (e.g. for several test sets) skips retraining
    %
    % Example:
    %     ds=cosmo_synthetic_dataset('ntargets',5,'nchunks',15);
    %     test_chunk=1;
    %     te=cosmo_slice(ds,ds.sa.chunks==test_chunk);
    %     tr=cosmo_slice(ds,ds.sa.chunks~=test_chunk);
    %     unused=struct();
    %     pred=cosmo_classify_naive_bayes(tr.samples,tr.sa.targets,...
    %                                        te.samples,unused);
    %     disp([te.sa.targets pred])
    %     %||      1     1
    %     %||      2     2
    %     %||      3     3
    %     %||      4     4
    %     %||      5     5
    %
    % See also: cosmo_crossvalidate, cosmo_crossvalidation_measure
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % cache of the last training inputs and the model derived from them
    persistent cached_model
    persistent cached_samples_train
    persistent cached_targets_train

    is_cached = isequal(cached_targets_train, targets_train) && ...
                isequal(cached_samples_train, samples_train);

    if ~is_cached
        % training data changed: retrain, then remember what was used
        cached_model = train(samples_train, targets_train);
        cached_targets_train = targets_train;
        cached_samples_train = samples_train;
    end

    predicted = test(cached_model, samples_test);

function predicted = test(model, samples_test)
    % Predict a class label for every row of samples_test
    %
    % Inputs:
    %   model          struct with fields .mus, .vars, .log_class_probs
    %                  and .classes, as produced by train
    %   samples_test   QxR test data; R must equal size(model.mus,2)
    %
    % Output:
    %   predicted      Qx1 vector; row k holds the element of model.classes
    %                  with the largest summed log likelihood plus log prior
    %                  for test sample k
    class_mus = model.mus;
    class_vars = model.vars;
    log_priors = model.log_class_probs;
    labels = model.classes;

    [nsamples, nfeat] = size(samples_test);
    if size(class_mus, 2) ~= nfeat
        error('size mismatch');
    end

    predicted = zeros(nsamples, 1);

    for row = 1:nsamples
        % per-class, per-feature log likelihoods of this test sample
        feature_log_ps = log_normal_pdf(samples_test(row, :), ...
                                        class_mus, class_vars);

        % make octave more compatible with matlab: convert nan to 1
        feature_log_ps(isnan(feature_log_ps)) = 1;

        % independence assumption: the product of per-feature likelihoods
        % becomes a sum in log space; adding the log prior gives a score
        % proportional to the log posterior of each class
        class_scores = sum(feature_log_ps, 2) + log_priors;

        % pick the best-scoring class (first one on ties)
        [unused, best_idx] = max(class_scores);
        predicted(row) = labels(best_idx);
    end

function ps = log_normal_pdf(xs, mus, vars)
    % Elementwise log density of a univariate normal distribution
    %
    % Inputs:
    %   xs      values to evaluate; broadcast against mus with bsxfun
    %           (called from test with a single 1xR sample row)
    %   mus     means (called from test with a CxR matrix)
    %   vars    variances, same size as mus
    %
    % Output:
    %   ps      log(N(xs | mus, vars)), i.e.
    %           -1/2 * (log(2*pi*vars) + (xs-mus).^2 ./ vars)
    deltas = bsxfun(@minus, xs, mus);
    ps = -0.5 * (log(2 * pi * vars) + (deltas .^ 2) ./ vars);

function model = train(samples_train, targets_train)
    % Estimate per-class Gaussian feature statistics and class priors
    %
    % Inputs:
    %   samples_train      PxR training data for P samples and R features
    %   targets_train      Px1 training data classes
    %
    % Output:
    %   model              struct with fields:
    %     .mus             CxR per-class feature means, where C is the
    %                      number of classes found by cosmo_index_unique
    %     .vars            CxR per-class feature variances, normalized by
    %                      the number of samples in the class (same as
    %                      var(...,1), not the N-1 default of var)
    %     .log_class_probs Cx1 log of the fraction of training samples
    %                      belonging to each class
    %     .classes         class labels as returned by cosmo_index_unique
    %                      (presumably the unique values of targets_train)
    %
    % Raises an error if the number of samples and targets differ, or if
    % any class has fewer than min_nsamples_per_class samples.

    % minimum number of samples per class needed to estimate a variance
    min_nsamples_per_class = 2;

    [ntrain, nfeatures] = size(samples_train);
    if ntrain ~= numel(targets_train)
        error('size mismatch');
    end

    [class_idxs, classes_cell] = cosmo_index_unique({targets_train});
    classes = classes_cell{1};
    nclasses = numel(classes);

    % allocate space for statistics of each class
    mus = zeros(nclasses, nfeatures);
    vars = zeros(nclasses, nfeatures);
    log_class_probs = zeros(nclasses, 1);

    % compute means and variances of each class
    for k = 1:nclasses
        idx = class_idxs{k}; % rows of samples_train in class k
        nsamples_in_class = numel(idx); % number of samples
        if nsamples_in_class < min_nsamples_per_class
            % arguments must follow the order of the %d placeholders:
            % class label, actual count, required count
            error(['Cannot train: class %d has only %d samples, %d '...
                   'are required'], classes(k), nsamples_in_class, ...
                   min_nsamples_per_class);
        end

        d = samples_train(idx, :); % samples in this class
        mu = mean(d); % mean
        mus(k, :) = mu;

        % variance normalized by N (equivalent to var(d,1)); computed
        % directly because it is faster than calling 'var'
        vars(k, :) = sum(bsxfun(@minus, mu, d).^2, 1) / nsamples_in_class;

        % log of class probability (prior)
        log_class_probs(k) = log(nsamples_in_class / ntrain);
    end

    model.mus = mus;
    model.vars = vars;
    model.log_class_probs = log_class_probs;
    model.classes = classes;