function predicted = cosmo_classify_naive_bayes(samples_train, targets_train, samples_test, unused)
% Naive Bayes classifier
%
% predicted=cosmo_classify_naive_bayes(samples_train, targets_train, samples_test[, opt])
%
% Inputs:
% samples_train PxR training data for P samples and R features
% targets_train Px1 training data classes
% samples_test QxR test data
% opt (currently ignored)
%
% Output:
% predicted Qx1 predicted data classes for samples_test
%
% Example:
% ds=cosmo_synthetic_dataset('ntargets',5,'nchunks',15);
% test_chunk=1;
% te=cosmo_slice(ds,ds.sa.chunks==test_chunk);
% tr=cosmo_slice(ds,ds.sa.chunks~=test_chunk);
% unused=struct();
% pred=cosmo_classify_naive_bayes(tr.samples,tr.sa.targets,...
% te.samples,unused);
% disp([te.sa.targets pred])
% %|| 1 1
% %|| 2 2
% %|| 3 3
% %|| 4 4
% %|| 5 5
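%
% % the classifier can also be plugged into a crossvalidation measure;
% % a minimal sketch, assuming the standard 'classifier' and
% % 'partitions' options of cosmo_crossvalidation_measure:
% opt=struct();
% opt.classifier=@cosmo_classify_naive_bayes;
% opt.partitions=cosmo_nfold_partitioner(ds);
% acc_ds=cosmo_crossvalidation_measure(ds,opt);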
%
% See also: cosmo_crossvalidate, cosmo_crossvalidation_measure
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
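% cache the most recently seen training data and the model trained
% from it, so that repeated calls with identical training data (as in
% cross-validation or searchlight analyses) can skip retraining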
persistent cached_targets_train
persistent cached_samples_train
persistent cached_model
if isequal(cached_targets_train, targets_train) && ...
isequal(cached_samples_train, samples_train)
model = cached_model;
else
model = train(samples_train, targets_train);
cached_targets_train = targets_train;
cached_samples_train = samples_train;
cached_model = model;
end
predicted = test(model, samples_test);
function predicted = test(model, samples_test)
mus = model.mus;
vars = model.vars;
log_class_probs = model.log_class_probs;
classes = model.classes;
[ntest, nfeatures] = size(samples_test);
if nfeatures ~= size(mus, 2)
error('size mismatch: test data has %d features, model expects %d', ...
nfeatures, size(mus, 2));
end
predicted = zeros(ntest, 1);
for k = 1:ntest
sample = samples_test(k, :);
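% log-density of the sample under each class's per-feature normal
% distribution; log_ps has size nclasses x nfeatures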
log_ps = log_normal_pdf(sample, mus, vars);
% features with zero variance in a class yield NaN log-densities
% (0/0 or Inf-Inf in log_normal_pdf); replace them by a constant so
% the sum below stays finite, and so Octave behaves like Matlab
log_ps(isnan(log_ps)) = 1;
% the 'naive' assumption is conditional independence of the features
% given the class, so the joint density is the product of the
% per-feature densities; summing the log-densities instead avoids
% numerical underflow. Adding the log class prior gives the
% (unnormalized) log posterior for each class.
log_test_prob = sum(log_ps, 2) + log_class_probs;
% assign the class with the highest log posterior
[unused, mx_idx] = max(log_test_prob);
predicted(k) = classes(mx_idx);
end
function ps = log_normal_pdf(xs, mus, vars)
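% elementwise log of the normal probability density:
% log N(x; mu, v) = -(log(2*pi*v) + (x-mu)^2/v) / 2
% bsxfun broadcasts the 1xR test sample against the CxR class means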
ps = -.5 * (log(2 * pi * vars) + bsxfun(@minus, xs, mus).^2 ./ vars);
function model = train(samples_train, targets_train)
[ntrain, nfeatures] = size(samples_train);
if ntrain ~= numel(targets_train)
error('size mismatch: %d training samples, but %d targets', ...
ntrain, numel(targets_train));
end
[class_idxs, classes_cell] = cosmo_index_unique({targets_train});
classes = classes_cell{1};
nclasses = numel(classes);
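% the model consists, per class, of an independent normal distribution
% for each feature (its mean and variance) plus a log prior estimated
% from the class frequencies in the training set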
% allocate space for statistics of each class
mus = zeros(nclasses, nfeatures);
vars = zeros(nclasses, nfeatures);
log_class_probs = zeros(nclasses, 1);
% compute the mean, variance and log prior of each class
for k = 1:nclasses
idx = class_idxs{k};
nsamples_in_class = numel(idx); % number of samples
if nsamples_in_class < 2
error(['Cannot train: class %d has only %d samples, at least '...
'2 are required'], classes(k), nsamples_in_class);
end
d = samples_train(idx, :); % samples in this class
mu = mean(d); % mean
mus(k, :) = mu;
% biased (maximum-likelihood) variance, dividing by N rather than
% N-1; computed directly because that is faster than calling 'var'
vars(k, :) = sum(bsxfun(@minus, mu, d).^2, 1) / nsamples_in_class;
% log of class probability
log_class_probs(k) = log(nsamples_in_class / ntrain);
end
model.mus = mus;
model.vars = vars;
model.log_class_probs = log_class_probs;
model.classes = classes;