function result=cosmo_naive_bayes_classifier_searchlight(ds, nbrhood, varargin)
% Run (fast) Naive Bayes classifier searchlight with crossvalidation
%
% result=cosmo_naive_bayes_classifier_searchlight(ds, nbrhood, ...)
%
% Inputs:
% ds dataset struct
% nbrhood Neighborhood structure with fields:
% .a struct with dataset attributes
% .fa struct with feature attributes. Each field
% should have NF values in the second dimension
% .neighbors cell with NF mappings from center_ids in output
% dataset to feature ids in input dataset.
% Suitable neighborhood structs can be generated
% using:
% - cosmo_spherical_neighborhood (fmri volume)
% - cosmo_surficial_neighborhood (fmri surface)
% - cosmo_meeg_chan_neighborhood (MEEG channels)
% - cosmo_interval_neighborhood (MEEG time, freq)
% - cosmo_cross_neighborhood (to cross neighborhoods
% from the neighborhood
% functions above)
% 'partitions', par Partition scheme to use. Typically this is the
% output from cosmo_nfold_partitioner or
% cosmo_oddeven_partitioner. Partition schemes
% in which samples occur in the test set of more
% than one fold (such as the output from
% cosmo_nchoosek_partitioner(N) with N>1) are
% supported as well; for such schemes, the
% 'winner_predictions' output takes, for each
% sample, the most common prediction across folds
% 'output', out One of:
% 'accuracy' return classification accuracy
% (default)
% 'winner_predictions' return the winning (most common)
% prediction for each sample across
% test folds. ('predictions' is a
% deprecated synonym and will be
% removed in a future release)
% 'progress', p Show progress every p folds (default: 1)
%
% Output:
% result a dataset struct in which the samples
% contain classification accuracies or winning
% class predictions, depending on the 'output'
% option
%
% Example:
% % generate tiny dataset (6 voxels) and define a tiny spherical
% % neighborhood with a radius of 1 voxel.
% ds=cosmo_synthetic_dataset('nchunks',10,'ntargets',5);
% nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
% %
% % set options
% opt=struct();
% opt.progress=false;
% % define take-one-chunk-out crossvalidation scheme (10 folds)
% opt.partitions=cosmo_nfold_partitioner(ds);
% %
% % run searchlight
% result=cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);
% %
% % show result
% cosmo_disp(result);
% %|| .a
% %|| .fdim
% %|| .labels
% %|| { 'i' 'j' 'k' }
% %|| .values
% %|| { [ 1 2 3 ] [ 1 2 ] [ 1 ] }
% %|| .vol
% %|| .mat
% %|| [ 2 0 0 -3
% %|| 0 2 0 -3
% %|| 0 0 2 -3
% %|| 0 0 0 1 ]
% %|| .dim
% %|| [ 3 2 1 ]
% %|| .xform
% %|| 'scanner_anat'
% %|| .fa
% %|| .nvoxels
% %|| [ 3 4 3 3 4 3 ]
% %|| .radius
% %|| [ 1 1 1 1 1 1 ]
% %|| .center_ids
% %|| [ 1 2 3 4 5 6 ]
% %|| .i
% %|| [ 1 2 3 1 2 3 ]
% %|| .j
% %|| [ 1 1 1 2 2 2 ]
% %|| .k
% %|| [ 1 1 1 1 1 1 ]
% %|| .samples
% %|| [ 0.6 0.74 0.44 0.6 0.7 0.4 ]
% %|| .sa
% %|| .labels
% %|| { 'accuracy' }
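% %
% % as a variation, return the winning prediction for each sample
% % rather than accuracies; a sketch, with its output not shown here:
% opt.output='winner_predictions';
% pred_result=cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);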
%
%
% Notes:
% - this function runs considerably faster than the generic
% cosmo_searchlight with a classifier function and a crossvalidation
% scheme, because model parameters are estimated only once per feature
% per fold, rather than once per feature for every searchlight center.
% Speedups are therefore greatest when neighboring searchlights share
% many features
% - for other classifiers or other measures, use the more flexible
% cosmo_searchlight function
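% - as an illustration of the previous note, an equivalent (but much
% slower) analysis can be run with the generic searchlight; a sketch,
% assuming the dataset ds, neighborhood nh and options opt from the
% example above:
%
% opt.classifier=@cosmo_classify_naive_bayes;
% slow_result=cosmo_searchlight(ds,nh,...
% @cosmo_crossvalidation_measure,opt);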
%
% See also: cosmo_searchlight
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
% check input
cosmo_check_dataset(ds);
% set defaults
defaults=struct();
defaults.output='accuracy';
defaults.progress=1;
opt=cosmo_structjoin(defaults,varargin{:});
show_progress=isfield(opt,'progress') && opt.progress;
% get neighborhood in matrix form for faster lookups
nbrhood_mat=cosmo_convert_neighborhood(nbrhood,'matrix');
ncenters=size(nbrhood_mat,2);
nsamples=size(ds.samples,1);
% get partitions for crossvalidation
[train_idxs,test_idxs]=get_partitions(ds,opt);
nfolds=numel(train_idxs);
max_prediction_count=get_max_prediction_count(test_idxs);
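% predictions has one row per sample, one column per searchlight center,
% and one slice along the third dimension for each time a sample can
% occur in a test set across folds; unused slots remain NaN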
predictions=NaN(nsamples,ncenters,max_prediction_count);
prediction_count=zeros(nsamples,1);
if show_progress
clock_start=clock();
prev_progress_msg='';
end
% perform classification for each fold
for fold=1:nfolds
train_idx=train_idxs{fold};
test_idx=test_idxs{fold};
samples_train=ds.samples(train_idx,:);
targets_train=ds.sa.targets(train_idx);
% estimate parameters
model=naive_bayes_train(samples_train, targets_train);
% predict classes
test_samples=ds.samples(test_idx,:);
fold_pred=naive_bayes_predict(model, nbrhood_mat, test_samples);
% store each fold's predictions in the first free slot along the
% third dimension; working backwards over slots ensures that a row
% whose count has just been incremented cannot match again in a
% later iteration of this loop
for col=max_prediction_count:-1:1
row_msk=prediction_count(test_idx)==(col-1);
row=test_idx(row_msk);
predictions(row,:,col)=fold_pred(row_msk,:);
prediction_count(row)=prediction_count(row)+1;
end
if show_progress && (fold<10 || ~mod(fold,opt.progress) || ...
fold==nfolds)
msg='';
prev_progress_msg=cosmo_show_progress(clock_start, ...
fold/nfolds, msg, prev_progress_msg);
end
end
result=struct();
result.a=nbrhood.a;
result.fa=nbrhood.fa;
% set output
output=opt.output;
switch output
case 'accuracy'
is_pred=~isnan(predictions);
is_correct=is_pred & bsxfun(@eq,predictions,ds.sa.targets);
correct_count=sum(sum(is_correct,1),3);
pred_count=sum(sum(is_pred,1),3);
result.samples=correct_count./pred_count;
result.sa.labels={'accuracy'};
case {'winner_predictions','predictions'}
if cosmo_match({output},{'predictions'})
cosmo_warning('CoSMoMVPA:deprecated',...
sprintf(...
['Output option ''%s'' is deprecated and will '...
'be removed from a future release. Please use '...
'output=''winner_predictions'' instead.'],...
output));
end
if max_prediction_count<=1
winners=predictions;
else
winners=zeros(nsamples,ncenters);
for k=1:ncenters
pred_mat=reshape(predictions(:,k,:),...
nsamples,max_prediction_count);
[idx,cl]=cosmo_winner_indices(pred_mat);
winners(:,k)=cl(idx);
if show_progress && (k<10 || ...
~mod(k,opt.progress) || ...
k==ncenters)
msg='computing winners';
prev_progress_msg=cosmo_show_progress(...
clock_start, ...
k/ncenters, msg, prev_progress_msg);
end
end
end
result.samples=winners;
result.sa=rmfield(ds.sa,'chunks');
otherwise
error(['illegal output ''%s'', must be '...
'''accuracy'' or ''winner_predictions'''], opt.output);
end
cosmo_check_dataset(result);
function [train_idxs,test_idxs]=get_partitions(ds,opt)
if ~isfield(opt,'partitions')
error('the ''partitions'' option is required');
end
partitions=opt.partitions;
cosmo_check_partitions(partitions,ds);
train_idxs=partitions.train_indices;
test_idxs=partitions.test_indices;
function max_prediction_count=get_max_prediction_count(test_idxs)
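% return the maximum number of folds in which any single sample occurs
% in a test set; this determines how many slices the predictions array
% needs along its third dimension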
max_index=max(cellfun(@max,test_idxs));
h=zeros(max_index,1);
nfolds=numel(test_idxs);
for fold=1:nfolds
idx=test_idxs{fold};
h(idx)=h(idx)+1;
end
max_prediction_count=max(h);
function pred=naive_bayes_predict(model,nbrhood_mat,samples_test)
classes=model.classes;
nclasses=numel(classes);
[max_neighbors,ncenters]=size(nbrhood_mat);
[nsamples,nfeatures]=size(samples_test);
if nfeatures~=size(model.mus,2)
error(['size mismatch in number of features between training '...
'and test set']);
end
max_ps=-Inf(nsamples,ncenters);
pred=NaN(nsamples,ncenters);
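% for each class c with per-feature Gaussian parameters mu_f and var_f
% and prior P(c), the Naive Bayes log-posterior of a sample x over the
% features F in a neighborhood is, up to a constant:
%
% log P(c) + sum_{f in F} -.5*(log(2*pi*var_f) + (x_f-mu_f)^2/var_f)
%
% each center's predicted class is the one maximizing this value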
for j=1:nclasses
mu=model.mus(j,:);
var_=model.vars(j,:);
% per-feature Gaussian log-likelihoods (nsamples x nfeatures)
xs_z=bsxfun(@rdivide,bsxfun(@minus,samples_test,mu).^2,var_);
log_ps=-.5*bsxfun(@plus,log(2*pi*var_),xs_z);
% add the log class prior once per center (not once per feature),
% so that after summing over a neighborhood's features the value
% is log P(c) + sum_f log P(x_f | c)
log_sum_ps=model.log_class_probs(j)+zeros(nsamples,ncenters);
for k=1:max_neighbors
row_msk=nbrhood_mat(k,:)>0;
log_sum_ps(:,row_msk)=log_sum_ps(:,row_msk)+...
log_ps(:,nbrhood_mat(k,row_msk));
end
greater_ps_mask=log_sum_ps>max_ps;
max_ps(greater_ps_mask)=log_sum_ps(greater_ps_mask);
pred(greater_ps_mask)=classes(j);
end
function model=naive_bayes_train(samples_train, targets_train)
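% fit a Gaussian Naive Bayes model: per-class feature means and
% (maximum-likelihood) variances, together with log class priors
% estimated from the relative class frequencies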
[ntrain,nfeatures]=size(samples_train);
if ntrain~=numel(targets_train)
error('size mismatch between samples and targets');
end
[class_idxs,classes_cell]=cosmo_index_unique({targets_train});
classes=classes_cell{1};
nclasses=numel(classes);
% allocate space for statistics of each class
mus=zeros(nclasses,nfeatures);
vars=zeros(nclasses,nfeatures);
log_class_probs=zeros(nclasses,1);
% compute mean, variance and log prior probability of each class
for k=1:nclasses
idx=class_idxs{k};
nsamples_in_class=numel(idx); % number of samples
if nsamples_in_class<2
error(['cannot train: class %d has only %d samples, '...
'at least 2 are required'],...
classes(k),nsamples_in_class);
end
d=samples_train(idx,:); % samples in this class
mu=mean(d,1); % mean of each feature
mus(k,:)=mu;
% biased (maximum-likelihood) variance estimate, computed
% directly because this is faster than the 'var' function
vars(k,:)=sum(bsxfun(@minus,mu,d).^2,1)/nsamples_in_class;
% log of class probability
log_class_probs(k)=log(nsamples_in_class/ntrain);
end
model.mus=mus;
model.vars=vars;
model.log_class_probs=log_class_probs;
model.classes=classes;