test interval neighborhood

function test_suite = test_interval_neighborhood()
    % tests for cosmo_interval_neighborhood
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_interval_neighborhood_basis()
    ds_full = cosmo_synthetic_dataset('type', 'meeg', 'size', 'big');
    ds_full = cosmo_slice(ds_full, ds_full.fa.chan < 3, 2);
    ds_full = cosmo_dim_prune(ds_full);

    sliceargs = {1:7, [1 4 7], [2 6]};
    radii = [0 1 2];
    for k = 1:numel(sliceargs)
        slicearg = sliceargs{k};
        narg = numel(slicearg);

        ds = cosmo_slice(ds_full, cosmo_match(ds_full.fa.time, slicearg), 2);
        ds = cosmo_dim_prune(ds);
        nf = size(ds.samples, 2);
        ds = cosmo_slice(ds, randperm(nf), 2);

        for j = 1:numel(radii)
            ds = cosmo_slice(ds, randperm(nf), 2);
            fa_time = ds.fa.time;

            radius = radii(j);

            nh = cosmo_interval_neighborhood(ds, 'time', 'radius', radius);
            assert(numel(nh.neighbors) == narg);
            assertEqual(nh.fa.time, 1:narg);
            assertEqual(nh.a.fdim.values, ...
                        {ds_full.a.fdim.values{2}(slicearg)});

            assertEqual(nh.origin.a.fdim, ds.a.fdim);
            assertEqual(nh.origin.fa, ds.fa);

            for m = 1:narg
                msk = m - radius <= fa_time & ...
                        fa_time <= m + radius;
                assertEqual(find(msk), nh.neighbors{m});
            end

            % should properly deal with permutations
            ds2 = cosmo_slice(ds, randperm(nf), 2);
            ds2.a.fdim.values = cellfun(@transpose, ds2.a.fdim.values, ...
                                        'UniformOutput', false)';
            nh2 = cosmo_interval_neighborhood(ds2, 'time', 'radius', radius);
            assertEqual(nh.fa, nh2.fa);
            assertEqual(nh.a, nh2.a);
            mp = cosmo_align(ds.fa, ds2.fa);
            for m = 1:numel(nh.neighbors)
                assertEqual(sort(mp(nh2.neighbors{m})), nh.neighbors{m});
            end
        end
    end

    % test exceptionsclc
    aet = @(x, i)assertExceptionThrown(@() ...
                                       cosmo_interval_neighborhood(x{:}), i);

    aet({ds}, '');
    aet({ds, 'time'}, '');
    aet({ds, 'x', 2}, '');
    aet({ds, 'time', -1}, '');
    aet({ds, 'time', [2 3]}, '');
    aet({ds, 'time', 'radius'}, '');
    aet({ds, 'time', 'radius', -1}, '');

function test_interval_neighborhood_sa()
    ds = cosmo_synthetic_dataset('type', 'meeg');
    ds_tr = cosmo_dim_transpose(ds, 'time');
    ds_tr = cosmo_slice(ds_tr, randperm(12));
    for radius = 0:1
        nbrhood = cosmo_interval_neighborhood(ds_tr, 'time', 'radius', radius);
        unq_time = unique(ds_tr.sa.time);
        for k = 1:numel(unq_time)
            msk = abs(ds_tr.sa.time - unq_time(k)) <= radius;
            assertEqual(nbrhood.neighbors{k}, find(msk)');
        end
    end

function test_interval_neighborhood_fa()
    ds = cosmo_synthetic_dataset('type', 'meeg', 'size', 'big');
    nf = size(ds.samples, 2);
    rp = randperm(nf);
    dsp = cosmo_slice(ds, rp, 2);

    for radius = 0:10
        nhp = cosmo_interval_neighborhood(dsp, 'time', 'radius', radius);
        assertEqual(nhp.a.fdim.values{1}, ds.a.fdim.values{2});

        for k = 1:numel(nhp.neighbors)
            idx = find(abs(nhp.fa.time(k) - dsp.fa.time) <= radius);
            assertEqual(nhp.neighbors{k}, idx);
        end
    end

function test_sparse_interval_neighborhood
    ds = cosmo_synthetic_dataset('size', 'big');

    % make some holes
    ds = cosmo_slice(ds, ds.fa.i >= 4 & ds.fa.i <= 16, 2);
    ds = cosmo_slice(ds, mod(ds.fa.i, 3) <= 1, 2);

    for radius = 0:5
        nh = cosmo_interval_neighborhood(ds, 'i', 'radius', radius);

        n_nbrs = numel(nh.neighbors);

        assert(n_nbrs == numel(ds.a.fdim.values{1}));
        for j = 1:n_nbrs
            nbrs = nh.neighbors{j};

            idx = find(j - radius <= ds.fa.i & ds.fa.i <= j + radius);
            assertEqual(nbrs(:), idx(:), sprintf(['not equal with '...
                                                  'radius=%d, index=%d'], ...
                                                 radius, j));
        end
    end