cosmo naive bayes classifier searchlight skl

function result=cosmo_naive_bayes_classifier_searchlight(ds, nbrhood, varargin)
% Run (fast) Naive Bayes classifier searchlight with crossvalidation
%
% result=cosmo_naive_bayes_classifier_searchlight(ds, nbrhood, ...)
%
% Inputs:
%   ds                   dataset struct with (at least) fields .samples
%                        and .sa.targets
%   nbrhood              Neighborhood structure with fields:
%         .a               struct with dataset attributes
%         .fa              struct with feature attributes. Each field
%                            should have NF values in the second dimension
%         .neighbors       cell with NF mappings from center_ids in output
%                        dataset to feature ids in input dataset.
%                        Suitable neighborhood structs can be generated
%                        using:
%                        - cosmo_spherical_neighborhood (fmri volume)
%                        - cosmo_surficial_neighborhood (fmri surface)
%                        - cosmo_meeg_chan_neighborhood (MEEG channels)
%                        - cosmo_interval_neighborhood (MEEG time, freq)
%                        - cosmo_cross_neighborhood (to cross neighborhoods
%                                                    from the neighborhood
%                                                    functions above)
%   'partitions', par    (required) Partition scheme to use. Typically this
%                        is the output from cosmo_nfold_partitioner or
%                        cosmo_oddeven_partitioner. If a sample is in the
%                        test set of more than one fold (such as with
%                        cosmo_nchoosek_partitioner(N) with N>1), then
%                        all its predictions are pooled when
%                        output='accuracy', and the winning prediction (as
%                        computed by cosmo_winner_indices) is returned
%                        when output='winner_predictions'.
%   'output', out        One of:
%                        'accuracy'            (default) return
%                                              classification accuracy
%                        'winner_predictions'  return prediction for each
%                                              sample in a test set in
%                                              partitions (NaN for samples
%                                              never in a test set)
%                        'predictions'         deprecated alias for
%                                              'winner_predictions'
%   'progress', p        Show progress every p folds (default: 1); use
%                        false to disable progress reporting
%
% Output:
%   result               a dataset struct where the samples
%                        contain classification accuracies (one row) or
%                        class predictions (one row per sample in ds), and
%                        with .a and .fa taken from nbrhood
%
% Example:
%     % generate tiny dataset (6 voxels) and define a tiny spherical
%     % neighborhood with a radius of 1 voxel.
%     ds=cosmo_synthetic_dataset('nchunks',10,'ntargets',5);
%     nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
%     %
%     % set options
%     opt=struct();
%     opt.progress=false;
%     % define take-one-chunk-out crossvalidation scheme (10 folds)
%     opt.partitions=cosmo_nfold_partitioner(ds);
%     %
%     % run searchlight
%     result=cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);
%     %
%     % show result
%     cosmo_disp(result);
%     %|| .a
%     %||   .fdim
%     %||     .labels
%     %||       { 'i'  'j'  'k' }
%     %||     .values
%     %||       { [ 1         2         3 ]  [ 1         2 ]  [ 1 ] }
%     %||   .vol
%     %||     .mat
%     %||       [ 2         0         0        -3
%     %||         0         2         0        -3
%     %||         0         0         2        -3
%     %||         0         0         0         1 ]
%     %||     .dim
%     %||       [ 3         2         1 ]
%     %||     .xform
%     %||       'scanner_anat'
%     %|| .fa
%     %||   .nvoxels
%     %||     [ 3         4         3         3         4         3 ]
%     %||   .radius
%     %||     [ 1         1         1         1         1         1 ]
%     %||   .center_ids
%     %||     [ 1         2         3         4         5         6 ]
%     %||   .i
%     %||     [ 1         2         3         1         2         3 ]
%     %||   .j
%     %||     [ 1         1         1         2         2         2 ]
%     %||   .k
%     %||     [ 1         1         1         1         1         1 ]
%     %|| .samples
%     %||   [ 0.6      0.74      0.44       0.6       0.7       0.4 ]
%     %|| .sa
%     %||   .labels
%     %||     { 'accuracy' }
%
%
% Notes:
%   - this function runs considerably faster than using a searchlight with
%     a classifier function and a crossvalidation scheme, because model
%     parameters during training are estimated only once for each feature.
%     Thus, speedups are most significant if elements in the neighborhood
%     have many overlapping features
%   - for other classifiers or other measures, use the more flexible
%     cosmo_searchlight function
%
% See also: cosmo_searchlight
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    % check input
    cosmo_check_dataset(ds);

    % set defaults
    defaults=struct();
    defaults.output='accuracy';
    defaults.progress=1;
    opt=cosmo_structjoin(defaults,varargin{:});
    show_progress=isfield(opt,'progress') && opt.progress;

    % get neighborhood in matrix form for faster lookups
    nbrhood_mat=cosmo_convert_neighborhood(nbrhood,'matrix');

    ncenters=size(nbrhood_mat,2);
    nsamples=size(ds.samples,1);

    % get partitions for crossvalidation
    [train_idxs,test_idxs]=get_partitions(ds,opt);
    nfolds=numel(train_idxs);

    % predictions(i,k,c) holds the c-th prediction for sample i at center
    % k; slots that are never filled (samples tested fewer times, or never)
    % remain NaN
    max_prediction_count=get_max_prediction_count(test_idxs);
    predictions=NaN(nsamples,ncenters,max_prediction_count);
    prediction_count=zeros(nsamples,1);

    if show_progress
        clock_start=clock();
        prev_progress_msg='';
    end

    % perform classification for each fold
    for fold=1:nfolds
        train_idx=train_idxs{fold};
        test_idx=test_idxs{fold};
        samples_train=ds.samples(train_idx,:);
        targets_train=ds.sa.targets(train_idx);

        % estimate parameters
        model=naive_bayes_train(samples_train, targets_train);

        % predict classes
        test_samples=ds.samples(test_idx,:);
        fold_pred=naive_bayes_predict(model, nbrhood_mat, test_samples);

        % store predictions; work backwards to ensure each is stored just
        % once (a row whose count is incremented in iteration col no
        % longer matches the mask of any later, smaller col)
        for col=max_prediction_count:-1:1
            row_msk=prediction_count(test_idx)==(col-1);
            row=test_idx(row_msk);
            predictions(row,:,col)=fold_pred(row_msk,:);
            prediction_count(row)=prediction_count(row)+1;
        end

        % always report the last fold so that progress ends at 100%
        if show_progress && (fold<10 || ~mod(fold,opt.progress) || ...
                                                fold==nfolds)
            msg='';
            prev_progress_msg=cosmo_show_progress(clock_start, ...
                            fold/nfolds, msg, prev_progress_msg);
        end

    end

    result=struct();
    result.a=nbrhood.a;
    result.fa=nbrhood.fa;

    % set output
    output=opt.output;
    switch output
        case 'accuracy'
            % pool all (non-NaN) predictions over samples and repeats
            is_pred=~isnan(predictions);
            is_correct=is_pred & bsxfun(@eq,predictions,ds.sa.targets);

            correct_count=sum(sum(is_correct,1),3);
            pred_count=sum(sum(is_pred,1),3);

            result.samples=correct_count./pred_count;
            result.sa.labels={'accuracy'};

        case {'winner_predictions','predictions'}
            if cosmo_match({output},{'predictions'})
                cosmo_warning('CoSMoMVPA:deprecated',...
                        sprintf(...
                        ['Output option ''%s'' is deprecated and will '...
                        'be removed from a future release. Please use '...
                        'output=''winner_predictions'' instead.'],...
                            output));
            end

            % samples without any prediction keep NaN
            winners=NaN(nsamples,ncenters);
            if max_prediction_count==1
                winners=predictions;
            elseif max_prediction_count>1
                for k=1:ncenters
                    pred_mat=reshape(predictions(:,k,:),...
                                        nsamples,max_prediction_count);
                    [idx,cl]=cosmo_winner_indices(pred_mat);

                    % NOTE(review): rows with only NaN predictions
                    % (samples never tested) may not get a valid index
                    % from cosmo_winner_indices; only index cl where idx
                    % is a valid index
                    has_winner=~isnan(idx) & idx>0;
                    winners(has_winner,k)=cl(idx(has_winner));

                    if show_progress && (k<10 || ...
                                        ~mod(k,opt.progress) || ...
                                        k==ncenters)
                        msg='computing winners';
                        prev_progress_msg=cosmo_show_progress(...
                                    clock_start, ...
                                    k/ncenters, msg, prev_progress_msg);
                    end
                end
            end

            result.samples=winners;

            % keep sample attributes, except chunks (if present)
            result.sa=ds.sa;
            if isfield(result.sa,'chunks')
                result.sa=rmfield(result.sa,'chunks');
            end

        otherwise
            error(['illegal output ''%s'', must be '...
                    '''accuracy'' or ''winner_predictions'''], opt.output);
    end

    cosmo_check_dataset(result);




function [train_idxs,test_idxs]=get_partitions(ds,opt)
% Extract train and test indices from the required 'partitions' option
%
% [train_idxs,test_idxs]=get_partitions(ds,opt)
%
% Inputs:
%   ds          dataset struct, used to validate the partitions
%   opt         options struct; must have a field .partitions
%
% Outputs:
%   train_idxs  the .train_indices field of opt.partitions
%   test_idxs   the .test_indices field of opt.partitions
%
% Raises an error if opt has no 'partitions' field, or if
% cosmo_check_partitions rejects the partitions for ds.
    has_partitions=isfield(opt,'partitions');
    if ~has_partitions
        error('the ''partitions'' option is required');
    end

    scheme=opt.partitions;
    cosmo_check_partitions(scheme,ds);

    [train_idxs,test_idxs]=deal(scheme.train_indices,...
                                scheme.test_indices);



function max_prediction_count=get_max_prediction_count(test_idxs)
% Return the maximum number of folds in which one sample is tested
%
% max_prediction_count=get_max_prediction_count(test_idxs)
%
% Input:
%   test_idxs             cell with, for each fold, a vector of (positive
%                         integer) sample indices in the test set
%
% Output:
%   max_prediction_count  maximum, over all samples, of the number of folds
%                         that have the sample in their test set; 0 if all
%                         test sets are empty (or there are no folds)
%
% Notes:
%   - folds with an empty test set are allowed; they contribute nothing
    % stack all test indices as one column vector (folds may store them as
    % row or column vectors)
    as_columns=cellfun(@(idx) idx(:),test_idxs(:),'UniformOutput',false);
    all_idxs=vertcat(as_columns{:});

    if isempty(all_idxs)
        max_prediction_count=0;
        return;
    end

    % count how often each sample index occurs over all folds
    counts=accumarray(all_idxs,1);
    max_prediction_count=max(counts);





function pred=naive_bayes_predict(model,nbrhood_mat,samples_test)
% Predict the class of each test sample at each searchlight center
%
% pred=naive_bayes_predict(model,nbrhood_mat,samples_test)
%
% Inputs:
%   model          struct from naive_bayes_train, with fields .classes,
%                  .mus and .vars (nclasses x nfeatures) and
%                  .log_class_probs (one value per class)
%   nbrhood_mat    max_neighbors x ncenters matrix; column k holds the
%                  feature indices of center k, with non-positive values
%                  marking unused slots
%   samples_test   nsamples x nfeatures test samples
%
% Output:
%   pred           nsamples x ncenters predicted class labels. For each
%                  sample and center, the class with the largest log
%                  posterior (log class prior + sum of per-feature
%                  Gaussian log-likelihoods over the features of that
%                  center) is chosen; on ties the class listed first in
%                  model.classes wins. Entries remain NaN if no class has
%                  a log posterior greater than -Inf (e.g. when it is NaN).
%
% Notes:
%   - the log class prior is added once per center rather than once per
%     feature, so the weight of the prior does not depend on the number
%     of features in a neighborhood.
    classes=model.classes;
    nclasses=numel(classes);

    [max_neighbors,ncenters]=size(nbrhood_mat);
    [nsamples,nfeatures]=size(samples_test);

    if nfeatures~=size(model.mus,2)
        error(['size mismatch in number of features between training '...
                'and test set']);
    end

    % neighbor slots that contain a feature; identical for every class
    slot_msk=nbrhood_mat>0;

    max_ps=-Inf(nsamples,ncenters);
    pred=NaN(nsamples,ncenters);

    for j=1:nclasses
        mu=model.mus(j,:);
        var_=model.vars(j,:);

        % per-feature Gaussian log-likelihood, nsamples x nfeatures
        xs_z=bsxfun(@rdivide,bsxfun(@minus,samples_test,mu).^2,var_);
        log_ps=-.5*(bsxfun(@plus,log(2*pi*var_),xs_z));

        % start from the log prior (once per center), then add the
        % log-likelihood of every feature in each center's neighborhood
        log_sum_ps=repmat(model.log_class_probs(j),nsamples,ncenters);
        for k=1:max_neighbors
            row_msk=slot_msk(k,:);
            log_sum_ps(:,row_msk)=log_sum_ps(:,row_msk)+...
                                    log_ps(:,nbrhood_mat(k,row_msk));
        end

        % strict '>' keeps the earlier class on ties
        greater_ps_mask=log_sum_ps>max_ps;
        max_ps(greater_ps_mask)=log_sum_ps(greater_ps_mask);
        pred(greater_ps_mask)=classes(j);
    end






function model=naive_bayes_train(samples_train, targets_train)
% Estimate per-class Gaussian parameters and class priors for naive Bayes
%
% model=naive_bayes_train(samples_train, targets_train)
%
% Inputs:
%   samples_train     ntrain x nfeatures training samples
%   targets_train     vector with ntrain class labels
%
% Output:
%   model             struct with fields:
%     .mus              nclasses x nfeatures per-class feature means
%     .vars             nclasses x nfeatures per-class feature variances,
%                       normalized by the number of samples in the class
%                       (not by n-1, as MATLAB's var does by default)
%     .log_class_probs  nclasses x 1 log of the fraction of training
%                       samples in each class
%     .classes          unique class labels, as returned by
%                       cosmo_index_unique
%
% Raises an error if the number of samples and targets differ, or if a
% class has fewer than two training samples.
%
% NOTE(review): a feature that is constant within a class gets variance 0,
% which leads to Inf/NaN log-likelihoods at prediction time.
    [ntrain,nfeatures]=size(samples_train);
    if ntrain~=numel(targets_train)
        error('size mismatch between samples and targets');
    end

    [class_idxs,classes_cell]=cosmo_index_unique({targets_train});
    classes=classes_cell{1};
    nclasses=numel(classes);

    % allocate space for statistics of each class
    mus=zeros(nclasses,nfeatures);
    vars=zeros(nclasses,nfeatures);
    log_class_probs=zeros(nclasses,1);

    min_samples_per_class=2;

    % compute means and variances of each class
    for k=1:nclasses
        idx=class_idxs{k};
        nsamples_in_class=numel(idx); % number of samples
        if nsamples_in_class<min_samples_per_class
            error(['Cannot train: class %d has only %d samples, %d '...
                    'are required'],classes(k),nsamples_in_class,...
                    min_samples_per_class);
        end

        d=samples_train(idx,:); % samples in this class
        mu=mean(d,1); % mean over samples
        mus(k,:)=mu;

        % variance with 1/n normalization - faster than 'var'
        vars(k,:)=sum(bsxfun(@minus,d,mu).^2,1)/nsamples_in_class;

        % log of class probability
        log_class_probs(k)=log(nsamples_in_class/ntrain);
    end

    model=struct();
    model.mus=mus;
    model.vars=vars;
    model.log_class_probs=log_class_probs;
    model.classes=classes;