% Tests for the naive Bayes classifier searchlight

function test_suite=test_naive_bayes_classifier_searchlight
% tests for cosmo_naive_bayes_classifier_searchlight
%
% Entry point of the test file: collects the test_* subfunctions below
% into test_suite.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % initTestSuite is a script run in this workspace; it presumably reads
    % 'test_functions' (when set) and assigns 'test_suite' - verify against
    % the test framework in use
    initTestSuite;

function test_naive_bayes_classifier_searchlight_tiny
% single searchlight that covers all features: the prediction for every
% fold must equal the output of calling the naive Bayes classifier directly
    ds=cosmo_synthetic_dataset('ntargets',25,'nchunks',4,'size','small');
    nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);

    % reduce the neighborhood to one center whose neighbors are all features
    all_features=1:size(ds.samples,2);
    nh.neighbors={all_features};
    nh.fa=cosmo_slice(nh.fa,1,2,'struct');

    opt=struct();
    opt.partitions=cosmo_nfold_partitioner(ds);
    opt.output='winner_predictions';
    opt.progress=false;

    result=cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);
    assertEqual(result.fa,nh.fa);
    assertEqual(result.a,nh.a);

    train_idxs=opt.partitions.train_indices;
    test_idxs=opt.partitions.test_indices;
    for k=1:numel(train_idxs)
        tr=train_idxs{k};
        te=test_idxs{k};

        expected=cosmo_classify_naive_bayes(ds.samples(tr,:),...
                                            ds.sa.targets(tr),...
                                            ds.samples(te,:));
        assertEqual(result.samples(te,:),expected);
    end

    % compare with standard searchlight
    assert_same_output_as_classifical_searchlight(ds,nh,opt);



function test_naive_bayes_classifier_searchlight_basics
% single all-feature searchlight: the 'winner_predictions' and 'accuracy'
% outputs must agree with each other and with the generic searchlight
    ds=cosmo_synthetic_dataset('ntargets',25,'nchunks',4,'size','small');
    nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);

    % one searchlight center containing every feature
    nfeatures=size(ds.samples,2);
    nh.neighbors={1:nfeatures};
    nh.fa=cosmo_slice(nh.fa,1,2,'struct');

    opt=struct('partitions',cosmo_nfold_partitioner(ds),...
               'output','winner_predictions',...
               'progress',false);

    preds=cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);
    assertEqual(preds.fa,nh.fa);
    assertEqual(preds.a,nh.a);

    % compare with standard searchlight
    assert_same_output_as_classifical_searchlight(ds,nh,opt);

    % accuracy must be the fraction of correct winner predictions
    opt.output='accuracy';
    acc=cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);
    is_correct=bsxfun(@eq,preds.samples,preds.sa.targets);
    assertEqual(acc.samples,mean(is_correct));
    assert_same_output_as_classifical_searchlight(ds,nh,opt);


function test_naive_bayes_classifier_searchlight_multiple_pred
% with nchoosek partitions each sample is in the test set of several folds;
% output must still match the generic searchlight
    ds=cosmo_synthetic_dataset('ntargets',25,'nchunks',5,'size','small');
    nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);

    % every pair of chunks is a test set, so each sample gets
    % multiple predictions
    opt=struct('partitions',cosmo_nchoosek_partitioner(ds,2),...
               'output','winner_predictions',...
               'progress',false);

    res=cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);
    assertEqual(res.fa,nh.fa);
    assertEqual(res.a,nh.a);

    % compare with standard searchlight, for both output types
    outputs={'winner_predictions','accuracy'};
    for k=1:numel(outputs)
        opt.output=outputs{k};
        assert_same_output_as_classifical_searchlight(ds,nh,opt);
    end




function assert_same_output_as_classifical_searchlight(ds,nh,opt)
% compare cosmo_naive_bayes_classifier_searchlight with the generic
% cosmo_searchlight running cosmo_crossvalidation_measure with the naive
% Bayes classifier: .samples must be almost equal, everything else equal
    opt.progress=false;
    nb_result=cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);

    opt.classifier=@cosmo_classify_naive_bayes;
    sl_result=cosmo_searchlight(ds,nh,@cosmo_crossvalidation_measure,opt);

    assertElementsAlmostEqual(nb_result.samples,sl_result.samples);

    % apart from numerical precision in .samples, all remaining fields
    % must be identical
    assertEqual(rmfield(nb_result,'samples'),rmfield(sl_result,'samples'));


function test_naive_bayes_classifier_searchlight_exceptions
% invalid inputs to cosmo_naive_bayes_classifier_searchlight must raise
% an exception
    aet=@(varargin)assertExceptionThrown(@()...
                cosmo_naive_bayes_classifier_searchlight(varargin{:}),'');

    ds=cosmo_synthetic_dataset('size','small','nchunks',4);
    nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);

    opt=struct();
    opt.progress=false;
    opt.partitions=cosmo_nchoosek_partitioner(ds,2);

    % illegal output option
    opt.output='foo';
    aet(ds,nh,opt);

    % unsupported output. Use the intact dataset with valid partitions,
    % so that the exception can only be caused by the output option
    % (and not by the illegal partitions tested below)
    opt.output='fold_predictions';
    aet(ds,nh,opt);

    % missing samples, so illegal partitions
    opt.output='winner_predictions';
    ds_bad=cosmo_slice(ds,ds.sa.chunks<=2);
    opt.partitions=cosmo_nfold_partitioner(ds_bad);
    aet(ds_bad,nh,opt);


function test_naive_bayes_classifier_searchlight_deprecations
% output='predictions' is deprecated: using it must record exactly one
% warning
    ds=cosmo_synthetic_dataset('size','small','nchunks',4);
    nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);

    opt=struct('progress',false,...
               'partitions',cosmo_nchoosek_partitioner(ds,2),...
               'output','predictions');

    % restore the original warning state when this function exits
    saved_state=cosmo_warning();
    restore_warnings=onCleanup(@()cosmo_warning(saved_state));
    cosmo_warning('reset');
    cosmo_warning('off');

    % after the reset no warnings have been shown yet
    state=cosmo_warning();
    assertEqual(numel(state.shown_warnings),0);

    % deprecated output option should trigger a warning
    cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);

    state=cosmo_warning();
    assertEqual(numel(state.shown_warnings),1);






function test_naive_bayes_classifier_searchlight_partial_partitions
% partitions in which some samples are never tested: the output must be a
% valid dataset whose .sa equals the input .sa without chunks, and it must
% agree with the generic searchlight
    nchunks=4;
    ds=cosmo_synthetic_dataset('ntargets',5,'nchunks',nchunks,...
                        'size','small');
    nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);

    % drop the first test sample of every fold, so that some samples
    % receive no prediction
    partitions=cosmo_nfold_partitioner(ds);
    nfolds=numel(partitions.test_indices);
    for k=1:nfolds
        partitions.test_indices{k}=partitions.test_indices{k}(2:end);
    end

    opt=struct();
    opt.progress=false;
    opt.partitions=partitions;
    opt.output='winner_predictions';

    res=cosmo_naive_bayes_classifier_searchlight(ds,nh,opt);

    ds_sa=rmfield(ds.sa,'chunks');
    assertEqual(res.sa,ds_sa);
    cosmo_check_dataset(res);

    assert_same_output_as_classifical_searchlight(ds,nh,opt);