% test searchlight

function test_suite = test_searchlight
    % tests for cosmo_searchlight
    %
    % Returns test_suite, which is expected to be populated by the
    % initTestSuite call below from the test_* local functions in this
    % file. initTestSuite is not defined here; presumably it comes from
    % the unit test framework (MOxUnit / xUnit) -- confirm.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % NOTE(review): initTestSuite presumably looks up the variable
        % named 'test_functions' in this workspace; do not rename it
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % intentionally empty: this is a best-effort step, and
        % initTestSuite works without test_functions in that case
    end
    initTestSuite;

function test_searchlight_singlethread()
    % run the shared searchlight checks without parallel workers and
    % with progress reporting disabled
    helper_test_searchlight(struct('progress', false));

function test_searchlight_matlab_multithread()
    % run the shared searchlight checks with nproc=2; the test is
    % reported as skipped when 'gcp' or 'parpool' cannot be found
    required_funcs = {'gcp', 'parpool'};
    for k = 1:numel(required_funcs)
        if isempty(which(required_funcs{k}))
            cosmo_notify_test_skipped('Matlab parallel toolbox not available');
            return
        end
    end

    % turn CoSMoMVPA warnings off; the onCleanup object restores the
    % previous warning state when this function exits
    orig_warning_state = cosmo_warning();
    state_restorer = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    searchlight_opt = struct('progress', false, 'nproc', 2);
    helper_test_searchlight(searchlight_opt);

function test_searchlight_octave_multithread()
    % run the shared searchlight checks with nproc=2; the test is
    % reported as skipped when 'parcellfun' or 'pararrayfun' cannot be
    % found
    required_funcs = {'parcellfun', 'pararrayfun'};
    for k = 1:numel(required_funcs)
        if isempty(which(required_funcs{k}))
            cosmo_notify_test_skipped('Octave parallel toolbox not available');
            return
        end
    end

    % turn CoSMoMVPA warnings off; the onCleanup object restores the
    % previous warning state when this function exits
    orig_warning_state = cosmo_warning();
    state_restorer = onCleanup(@()cosmo_warning(orig_warning_state));
    cosmo_warning('off');

    searchlight_opt = struct('progress', false, 'nproc', 2);
    helper_test_searchlight(searchlight_opt);

function helper_test_searchlight(opt)
    % shared checks for cosmo_searchlight; 'opt' holds searchlight
    % options (such as progress and nproc) passed to every call below

    % synthetic dataset, keeping only features with all |values| <= 3
    ds = cosmo_synthetic_dataset('size', 'normal');
    keep_msk = ~any(abs(ds.samples) > 3, 1);
    ds = cosmo_dim_prune(cosmo_slice(ds, keep_msk, 2));

    % measure returning the number of features in each searchlight
    count_measure = @(d, unused) cosmo_structjoin('samples', ...
                                                  size(d.samples, 2));

    % radius-based neighborhood
    nbrhood = cosmo_spherical_neighborhood(ds, 'radius', 2, 'progress', 0);
    res = cosmo_searchlight(ds, nbrhood, count_measure, opt);
    expected_counts = [8 12 10 9 12 10 16 13 12 17 14 15 ...
                       13 11 15 14 10 9 14 11 5 7 6 7];
    assertEqual(res.samples, expected_counts);
    for dim_label = {'i', 'j', 'k'}
        key = dim_label{1};
        assertEqual(res.fa.(key), ds.fa.(key));
    end
    assertEqual(res.a, ds.a);

    % count-based neighborhood; expected counts are at or near 17
    nbrhood = cosmo_spherical_neighborhood(ds, 'count', 17, 'progress', 0);
    res = cosmo_searchlight(ds, nbrhood, count_measure, opt);
    expected_counts = [17 17 17 17 17 17 17 17 17 17 18 16 ...
                       17 17 16 15 17 17 17 17 17 17 17 17];
    assertEqual(res.samples, expected_counts);

    % correlation measure, evaluated only at two selected centers
    nbrhood = cosmo_spherical_neighborhood(ds, 'radius', 2, ...
                                           cosmo_structjoin('progress', 0));
    res = cosmo_searchlight(ds, nbrhood, @cosmo_correlation_measure, ...
                            'center_ids', [4 21], opt);
    assertVectorsAlmostEqual(res.samples, [0.9742, -.0273], ...
                             'relative', .001);
    assertEqual(res.fa.i, [1 1]);
    assertEqual(res.fa.j, [2 1]);
    assertEqual(res.fa.k, [1 5]);

    % measure that also returns sample attributes and a sample
    % dimension; both must show up unchanged in the searchlight output
    out_sa = struct('time', (1:6)');
    out_sdim = struct();
    out_sdim.values = {10:15};
    out_sdim.labels = {'time'};
    mean_measure = @(d, unused) cosmo_structjoin( ...
                                    'samples', mean(d.samples, 2), ...
                                    'sa', out_sa, ...
                                    'a', cosmo_structjoin('sdim', out_sdim));

    % radius 0: each searchlight covers a single feature, so the mean
    % over features reproduces the input samples
    nbrhood = cosmo_spherical_neighborhood(ds, 'radius', 0, 'progress', false);
    res = cosmo_searchlight(ds, nbrhood, mean_measure, opt);
    assertEqual(res.sa, out_sa);
    assertEqual(res.a.sdim, out_sdim);
    assertElementsAlmostEqual(res.samples, ds.samples);

function test_searchlight_partial_classification
    % crossvalidation searchlight with a single partition (even chunks
    % for training, odd chunks for testing), using winner_predictions
    % output
    ds = cosmo_synthetic_dataset('nchunks', 6);

    is_train = mod(ds.sa.chunks, 2) == 0;
    partitions = struct();
    partitions.train_indices = {find(is_train)};
    partitions.test_indices = {find(~is_train)};

    opt = struct('classifier', @cosmo_classify_lda, ...
                 'partitions', partitions, ...
                 'progress', false, ...
                 'output', 'winner_predictions');

    nh = cosmo_spherical_neighborhood(ds, 'radius', 1, 'progress', false);
    res = cosmo_searchlight(ds, nh, @cosmo_crossvalidation_measure, opt);

    % NOTE(review): these rows are presumably the samples never used
    % for testing, hence NaN; the other rows must hold only 1 or 2
    assertEqual(res.samples([3 4 7 8 11 12], :), NaN(6, 6));
    tested = res.samples([1 2 5 6 9 10], :);
    assertTrue(all(ismember(tested(:), [1 2])));

function test_searchlight_exceptions
    % each invalid argument list must make cosmo_searchlight throw
    ds = cosmo_synthetic_dataset();
    nh = cosmo_spherical_neighborhood(ds, 'radius', 1, 'progress', false);
    good_measure = @(x, opt)cosmo_structjoin('samples', mean(x.samples, 2));
    % mean along dimension 1 returns a row vector instead of a column,
    % which the searchlight is expected to reject
    bad_measure = @(x, opt)cosmo_structjoin('samples', mean(x.samples, 1));

    bad_arg_lists = {{struct, nh, good_measure}, ... % not a dataset
                     {ds, ds, good_measure}, ...     % not a neighborhood
                     {ds, good_measure, nh}, ...     % swapped arguments
                     {ds, nh, bad_measure}};         % wrong output shape
    for k = 1:numel(bad_arg_lists)
        args = bad_arg_lists{k};
        assertExceptionThrown(@()cosmo_searchlight(args{:}, ...
                                                   'progress', 0), '');
    end

function test_searchlight_progress()
    % without an explicit 'progress' option (presumably enabled by
    % default), cosmo_searchlight should print a progress bar that ends
    % completely filled
    if cosmo_skip_test_if_no_external('!evalc')
        return
    end

    ds = cosmo_synthetic_dataset();
    nh = cosmo_spherical_neighborhood(ds, 'count', 2, 'progress', false);
    measure = @(x, opt)cosmo_structjoin('samples', mean(x.samples, 2));
    f = @()cosmo_searchlight(ds, nh, measure);

    % capture everything printed while the searchlight runs
    printed = evalc('f();');

    % use the framework assertion (as the other tests do) rather than
    % the builtin 'assert', so a failure reports what was expected
    full_bar = '[####################]';
    assertTrue(~isempty(strfind(printed, full_bar)), ...
               sprintf('progress output does not contain ''%s''', full_bar));