test searchlight

function test_suite = test_searchlight
% tests for cosmo_searchlight
%
% Entry point of the test suite; the test_* subfunctions below are
% presumably collected by initTestSuite (xUnit/MOxUnit) — verify against
% the test runner in use.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % handles to the local functions of this file; the variable name
        % 'test_functions' is assumed to be read by initTestSuite from this
        % workspace — do not rename (TODO confirm in initTestSuite)
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % localfunctions is not available here (older Matlab / Octave);
        % the error is deliberately ignored
    end
    initTestSuite;


function test_searchlight_singlethread()
% Run the shared searchlight checks in the current process, with progress
% reporting disabled.
    helper_test_searchlight(struct('progress',false));

function test_searchlight_matlab_multithread()
% Run the shared searchlight checks with nproc=2; the test is skipped when
% the Matlab parallel functions 'gcp' and 'parpool' are not on the path.
    required_funcs={'gcp','parpool'};
    for k=1:numel(required_funcs)
        if isempty(which(required_funcs{k}))
            cosmo_notify_test_skipped('Matlab parallel toolbox not available');
            return;
        end
    end

    % switch cosmo_warning off for the duration of this test; the onCleanup
    % object restores the saved state when it goes out of scope, so it must
    % stay assigned until the function returns
    saved_warning_state=cosmo_warning();
    restore_warnings=onCleanup(@()cosmo_warning(saved_warning_state));
    cosmo_warning('off');

    search_opt=struct('progress',false,'nproc',2);
    helper_test_searchlight(search_opt);

function test_searchlight_octave_multithread()
% Run the shared searchlight checks with nproc=2; the test is skipped when
% the Octave parallel functions 'parcellfun' and 'pararrayfun' are not on
% the path.
    required_funcs={'parcellfun','pararrayfun'};
    for k=1:numel(required_funcs)
        if isempty(which(required_funcs{k}))
            cosmo_notify_test_skipped('Octave parallel toolbox not available');
            return;
        end
    end

    % switch cosmo_warning off for the duration of this test; the onCleanup
    % object restores the saved state when it goes out of scope, so it must
    % stay assigned until the function returns
    saved_warning_state=cosmo_warning();
    restore_warnings=onCleanup(@()cosmo_warning(saved_warning_state));
    cosmo_warning('off');

    search_opt=struct('progress',false,'nproc',2);
    helper_test_searchlight(search_opt);


function helper_test_searchlight(opt)
% Shared cosmo_searchlight checks, run with the searchlight options in opt
% (e.g. progress, nproc) so that single- and multi-process runs are
% compared against the same expected values.

    % remove every feature that has any sample with absolute value > 3
    ds=cosmo_synthetic_dataset('size','normal');
    outlier_msk=any(abs(ds.samples)>3,1);
    ds=cosmo_dim_prune(cosmo_slice(ds,~outlier_msk,2));

    % measure returning the number of features in each searchlight
    count_measure=@(x,a) cosmo_structjoin('samples',size(x.samples,2));

    % (1) radius-based neighborhood: feature counts and feature attributes
    radius_nh=cosmo_spherical_neighborhood(ds,'radius',2,'progress',0);
    res=cosmo_searchlight(ds,radius_nh,count_measure,opt);
    expected_counts=[8 12 10 9 12 10 16 13 12 17 14 15 ...
                     13 11 15 14 10 9 14 11 5 7 6 7];
    assertEqual(res.samples,expected_counts);
    for dim_label={'i','j','k'}
        label=dim_label{1};
        assertEqual(res.fa.(label),ds.fa.(label));
    end
    assertEqual(res.a,ds.a);

    % (2) count-based neighborhood asking for 17 features per searchlight
    count_nh=cosmo_spherical_neighborhood(ds,'count',17,'progress',0);
    res=cosmo_searchlight(ds,count_nh,count_measure,opt);
    expected_counts=[17 17 17 17 17 17 17 17 17 17 18 16 ...
                     17 17 16 15 17 17 17 17 17 17 17 17];
    assertEqual(res.samples,expected_counts);

    % (3) correlation measure, evaluated only at center features 4 and 21;
    %     the neighborhood options are passed as a struct here
    corr_nh=cosmo_spherical_neighborhood(ds,'radius',2,...
                                cosmo_structjoin('progress',0));
    res=cosmo_searchlight(ds,corr_nh,@cosmo_correlation_measure,...
                                'center_ids',[4 21],opt);
    assertVectorsAlmostEqual(res.samples,[0.9742,-.0273],...
                                'relative',.001);
    assertEqual(res.fa.i,[1 1]);
    assertEqual(res.fa.j,[2 1]);
    assertEqual(res.fa.k,[1 5]);

    % (4) measure output with its own .sa and .a.sdim: these must appear
    %     unchanged in the searchlight result
    time_sa=struct('time',(1:6)');
    time_sdim=struct('values',{{10:15}},'labels',{{'time'}});

    single_nh=cosmo_spherical_neighborhood(ds,'radius',0,'progress',false);
    mean_measure=@(x,measure_opt)cosmo_structjoin(...
                                'samples',mean(x.samples,2),...
                                'sa',time_sa,...
                                'a',cosmo_structjoin('sdim',time_sdim));
    res=cosmo_searchlight(ds,single_nh,mean_measure,opt);
    assertEqual(res.sa,time_sa);
    assertEqual(res.a.sdim,time_sdim);
    assertElementsAlmostEqual(res.samples,ds.samples);


function test_searchlight_partial_classification
% Searchlight cross-validation with a single partition (train on samples
% with even chunks, test on samples with odd chunks), asking for winner
% predictions: rows not predicted must be NaN, predicted rows must hold
% target 1 or 2.
    ds=cosmo_synthetic_dataset('nchunks',6);

    is_train=mod(ds.sa.chunks,2)==0;
    partitions=struct('train_indices',{{find(is_train)}},...
                      'test_indices',{{find(~is_train)}});

    opt=struct('classifier',@cosmo_classify_lda,...
               'partitions',partitions,...
               'progress',false,...
               'output','winner_predictions');

    nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
    res=cosmo_searchlight(ds,nh,@cosmo_crossvalidation_measure,opt);

    % rows assumed to be the even-chunk (training) samples — no prediction
    assertEqual(res.samples([3 4 7 8 11 12],:),NaN(6,6))

    % rows assumed to be the odd-chunk (test) samples — predicted target
    predicted=res.samples([1 2 5 6 9 10],:);
    assertTrue(all(ismember(predicted(:),[1 2])));



function test_searchlight_exceptions
% cosmo_searchlight must throw for each invalid argument list below; every
% call gets 'progress',0 appended.
    ds=cosmo_synthetic_dataset();
    nh=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);

    % averages over the feature dimension (2): valid measure
    good_measure=@(x,opt)cosmo_structjoin('samples',mean(x.samples,2));
    % averages over the sample dimension (1) instead: invalid measure
    bad_measure=@(x,opt)cosmo_structjoin('samples',mean(x.samples,1));

    % invalid inputs: empty struct as dataset, dataset as neighborhood,
    % swapped neighborhood/measure arguments, measure with bad output
    bad_args={{struct,nh,good_measure},...
              {ds,ds,good_measure},...
              {ds,good_measure,nh},...
              {ds,nh,bad_measure}};

    for k=1:numel(bad_args)
        args=bad_args{k};
        assertExceptionThrown(@()cosmo_searchlight(args{:},...
                                        'progress',0),'');
    end



function test_searchlight_progress()
% Calling cosmo_searchlight without a 'progress' option must print a
% completed progress bar, checked by capturing the output with evalc.
    % presumably skips the test when evalc is unavailable — TODO confirm
    if cosmo_skip_test_if_no_external('!evalc')
        return;
    end

    ds=cosmo_synthetic_dataset();
    nh=cosmo_spherical_neighborhood(ds,'count',2,'progress',false);
    mean_measure=@(x,opt)cosmo_structjoin('samples',mean(x.samples,2));
    run_searchlight=@()cosmo_searchlight(ds,nh,mean_measure);

    % evalc evaluates the string in this workspace, so run_searchlight
    % is visible to it
    printed=evalc('run_searchlight();');
    assert(any(strfind(printed,'[####################]')));