test cluster neighborhood

function test_suite = test_cluster_neighborhood
    % tests for cosmo_cluster_neighborhood
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_fmri_cluster_neighborhood

    ds = cosmo_synthetic_dataset('type', 'fmri', 'size', 'normal');
    nf = size(ds.samples, 2);
    imsk = find(rand(1, nf) > .8);
    rp = randperm(numel(imsk));
    ds = cosmo_slice(ds, [imsk(rp) imsk(rp(end:-1:1))], 2);

    nh1 = cosmo_cluster_neighborhood(ds, 'progress', false);
    nh2 = cosmo_cluster_neighborhood(ds, 'progress', false, 'fmri', 1);
    nh_sph = cosmo_spherical_neighborhood(ds, 'progress', false, ...
                                          'radius', 2.5);
    nh3 = cosmo_cluster_neighborhood(ds, 'progress', false, 'fmri', nh_sph);

    ds.fa.sizes = ones(1, size(ds.samples, 2));
    assertEqual(nh1.a, ds.a);
    assertEqual(nh1.fa, ds.fa);

    nf = size(ds.samples, 2);
    rp = randperm(nf);

    nfeatures_test = nf;

    ijk = [ds.fa.i; ds.fa.j; ds.fa.k];
    feature_ids = rp(1:nfeatures_test);

    for j = 1:nfeatures_test
        feature_id = feature_ids(j);
        delta = sqrt(sum(bsxfun(@minus, ijk(:, feature_id), ijk).^2, 1));
        msk1 = delta < sqrt(3) + .001;
        assertEqual(sort(nh1.neighbors{feature_id}), find(msk1));

        msk2 = delta < sqrt(1) + .001;
        assertEqual(sort(nh2.neighbors{feature_id}), find(msk2));

        msk3 = delta < 2.5;
        assertEqual(sort(nh3.neighbors{feature_id}), find(msk3));
    end

function test_tiny_fmri_neighborhood
    ds = cosmo_synthetic_dataset('type', 'fmri', 'size', 'normal');
    ds = cosmo_slice(ds, [22 22], 2);

    % should pass
    nh = cosmo_cluster_neighborhood(ds, 'progress', false);

function test_meeg_cluster_neighborhood
    if cosmo_skip_test_if_no_external('fieldtrip')
        return
    end
    ds = cosmo_synthetic_dataset('type', 'timefreq', 'size', 'normal');
    nf = size(ds.samples, 2);
    imsk = find(rand(1, nf) > .4);
    rp = randperm(numel(imsk));
    ds = cosmo_slice(ds, [imsk(rp) imsk(rp(end:-1:1))], 2);
    ds = cosmo_dim_prune(ds);
    n = numel(ds.a.fdim.values{1});
    ds.a.fdim.values{1} = ds.a.fdim.values{1}(randperm(n));
    nf = size(ds.samples, 2);
    ds.fa.sizes = ones(1, size(ds.samples, 2));

    chan_nbrhood = cosmo_meeg_chan_neighborhood(ds, 'delaunay', true, ...
                                                'chantype', 'all', ...
                                                'label', 'dataset');

    assertEqual(ds.a.fdim.values(1), chan_nbrhood.a.fdim.values(1));

    ncombi = 7; % all 7 possibilities
    test_range = ceil(rand(1, ncombi) * 7);
    nfeatures_test = 3;

    for i = test_range
        use_chan = i <= 4;
        use_freq = mod(i, 2) == 1;
        use_time = mod(ceil(i / 2), 2) == 1;

        use_msk = [use_chan, use_freq, use_time];
        labels = {'chan', 'freq', 'time'};
        ndim = numel(labels);

        args = struct();
        for k = 1:ndim
            if ~use_msk(k)
                label = labels{k};
                args.(label) = false;
            end
        end

        cl_nbrhood = cosmo_cluster_neighborhood(ds, args, 'progress', false);
        assertEqual(cl_nbrhood.fa, ds.fa);
        assertEqual(cl_nbrhood.a, ds.a);

        rp = randperm(nf);
        for k = 1:nfeatures_test
            feature_id = rp(k);

            counter = zeros(1, nf);

            for j = 1:ndim
                label = labels{j};
                fa = ds.fa.(label);
                if use_msk(j)
                    if j == 1
                        % channel
                        ids = chan_nbrhood.neighbors{fa(feature_id)};
                    else
                        % anything else
                        ids = find(abs(fa - fa(feature_id)) <= 1.5);
                    end
                else
                    ids = find(fa == fa(feature_id));
                end

                counter(ids) = counter(ids) + 1;
            end

            nbrs = find(counter == ndim);
            assertEqual(nbrs, cl_nbrhood.neighbors{feature_id});

        end
    end

function test_tiny_meeg_cluster_neighborhood
    if cosmo_skip_test_if_no_external('fieldtrip')
        return
    end
    ds = cosmo_synthetic_dataset('type', 'meeg');
    nh = cosmo_cluster_neighborhood(ds, 'progress', false);

    ds.fa.sizes = ones(1, 6);
    assertEqual(ds.fa, nh.fa);

    assertEqual(ds.a, nh.a);

    half_neighbors = {[1 4]; [2 3 5 6]; [2 3 5 6]};
    assertEqual(nh.neighbors, repmat(half_neighbors, 2, 1));

    % transpose of fdim elements should be fine
    ds2 = ds;
    ds2.a.fdim.values = {ds2.a.fdim.values{1}', ds2.a.fdim.values{2}'};
    nh2 = cosmo_cluster_neighborhood(ds2, 'progress', false);
    cosmo_check_neighborhood(nh2, ds);
    cosmo_check_neighborhood(nh, ds);
    assertEqual(nh.a, nh2.a);
    assertEqual(nh.fa, nh2.fa);
    assertEqual(nh.neighbors, nh2.neighbors);

    % different fdim elements i not ok
    ds3 = ds2;
    ds3.a.fdim.values{1} = ds3.a.fdim.values{1}(end:-1:1);

function test_cluster_neighborhood_transpose
    opt = struct();
    opt.progress = false;
    ds = cosmo_synthetic_dataset('type', 'timefreq', 'size', 'normal');
    ds = cosmo_dim_remove(ds, 'chan');

    nh = cosmo_cluster_neighborhood(ds, opt);

    cp = cosmo_cartprod(repmat({[false, true]}, 4, 1));
    n = size(cp, 1);

    for k = 1:n
        t_label = cp{k, 1};
        t_value = cp{k, 2};
        t_elem1 = cp{k, 3};
        t_elem2 = cp{k, 4};

        ds2 = ds;

        if t_label
            ds2.a.fdim.labels = ds2.a.fdim.labels';
        end

        if t_value
            ds2.a.fdim.values = ds2.a.fdim.values';
        end

        if t_elem1
            ds2.a.fdim.values{1} = ds2.a.fdim.values{1}';
        end

        if t_elem2
            ds2.a.fdim.values{2} = ds2.a.fdim.values{2}';
        end

        nh2 = cosmo_cluster_neighborhood(ds2, opt);
        assertEqual(nh2.a, nh.a);
        assertEqual(nh2.fa, nh.fa);
        assertEqual(nh2.neighbors, nh.neighbors);
    end

function test_cluster_neighborhood_surface
    if cosmo_skip_test_if_no_external('surfing')
        return
    end
    ds = cosmo_synthetic_dataset('type', 'surface'); % ,'size','normal');

    vertices = [0 0 0 1 1 1
                1 2 3 1 2 3
                0 0 0 0 0 0]';
    faces = [3 2 3 2
             2 1 5 4
             5 4 6 5]';

    opt = struct();
    opt.progress = false;
    opt.vertices = vertices;
    opt.faces = faces;

    nh1 = cosmo_cluster_neighborhood(ds, opt);
    assertEqual(nh1.neighbors, { [1 2 4]
                                [1 2 3 4 5]
                                [2 3 5 6]
                                [1 2 4 5]
                                [2 3 4 5 6]
                                [3 5 6] });
    assertEqual(ds.a, nh1.a);
    ds.fa.sizes = [1 3 2 2 3 1] / 6;
    ds.fa.radius = sqrt([1 2 2 2 2 1]);
    ds.fa.area = [6 11 9 9 11 6] / 6;
    assertEqual(ds.fa, nh1.fa);

    opt.direct = false;
    nh2 = cosmo_cluster_neighborhood(ds, opt);
    assertEqual(nh2.neighbors, num2cell((1:6)'));
    assertEqual(ds.a, nh2.a);
    ds.fa.radius(:) = 0;
    ds.fa.area = [1 3 2 2 3 1] / 6;
    assertEqual(ds.fa, nh2.fa);

    % take a subset of features
    rp = randperm(6);
    nsel = 4;
    sel = rp(1:nsel);

    ds3 = cosmo_slice(ds, sel, 2);

    % re-order node ids
    rp2 = randperm(6);
    ds.a.fdim.values{1} = ds.a.fdim.values{1}(rp2);

    opt.direct = true;
    nh3 = cosmo_cluster_neighborhood(ds3, opt);
    assertEqual(nh3.a, ds3.a);
    assertEqual(nh3.fa.node_indices, ds3.fa.node_indices);

    node_ids = ds3.a.fdim.values{1}(ds3.fa.node_indices);
    for k = 1:nsel
        nbs = nh3.neighbors{k}(:);
        node_id = node_ids(k);

        nb1 = intersect(nh1.neighbors{node_id}, sel);
        assertEqual(sort(nb1'), sort(node_ids(nbs))');
    end

function test_cluster_neighborhood_source
    ds = cosmo_synthetic_dataset('size', 'normal', 'type', 'source');
    nf = size(ds.samples, 2);
    [unused, idxs] = sort(cosmo_rand(1, nf * 3));
    rps = mod(idxs - 1, nf) + 1;
    rp = rps(round(nf / 2) + (1:(2 * nf)));
    ds = cosmo_slice(ds, rp, 2);

    [unused, rp] = sort(cosmo_rand(1, nf));
    rp = rp(1:4);

    grid_spacing = 10;
    ds_pos = ds.a.fdim.values{1}(:, ds.fa.pos) / grid_spacing;

    for connectivity = 0:3
        if connectivity == 0
            radius = sqrt(3) + .001;
            args = {};
        else
            args = {'source', connectivity};
            radius = sqrt(connectivity) + .001;
        end

        nh = cosmo_cluster_neighborhood(ds, 'progress', false, args);
        nh_pos = nh.a.fdim.values{1}(:, nh.fa.pos) / grid_spacing;

        for r = rp
            idxs = nh.neighbors{r};

            d = sum(bsxfun(@minus, nh_pos(:, r), ds_pos).^2, 1).^.5;

            d_inside = d(idxs);
            outside_mask = true(size(d));
            outside_mask(d <= radius) = false;
            d_outside = d(outside_mask);
            assert(all(d_inside <= radius));
            assert(all(d_outside > radius));
        end
    end

function test_cluster_neighborhood_source_mom
    ds = cosmo_synthetic_dataset('size', 'normal', 'type', 'source', ...
                                 'data_field', 'mom');
    nf = size(ds.samples, 2);
    [unused, idxs] = sort(cosmo_rand(1, nf * 3));
    rps = mod(idxs - 1, nf) + 1;
    rp = rps(round(nf / 2) + (1:(2 * nf)));
    ds = cosmo_slice(ds, rp, 2);

    grid_spacing = 10;

    for connectivity = 0:3
        if connectivity == 0
            radius = sqrt(3) + .001;
            args = {};
        else
            args = {'source', connectivity};
            radius = sqrt(connectivity) + .001;
        end

        for has_3d_mom = [false true]
            ds2 = ds;
            if has_3d_mom
                assertExceptionThrown(@() ...
                                      cosmo_cluster_neighborhood(ds, ...
                                                                 'progress', false, args), '');
                continue
            end

            keep_dim = ceil(rand() * 3);

            keep_msk = ds2.fa.mom == keep_dim;

            ds2 = cosmo_slice(ds, keep_msk, 2);
            ds2.fa.mom(:) = 1;
            ds2.a.fdim.values{2} = ds2.a.fdim.values{2}(keep_dim);

            ds2_pos = ds2.a.fdim.values{1}(:, ds2.fa.pos) / grid_spacing;

            nf = size(ds2.samples, 2);
            [unused, rp] = sort(cosmo_rand(1, nf));
            rp = rp(1:4);

            nh = cosmo_cluster_neighborhood(ds2, 'progress', false, args);
            nh_pos = nh.a.fdim.values{1}(:, nh.fa.pos) / grid_spacing;

            for r = rp
                idxs = nh.neighbors{r};

                d = sum(bsxfun(@minus, nh_pos(:, r), ...
                               ds2_pos).^2, 1).^.5;

                d_inside = d(idxs);
                outside_mask = true(size(d));
                outside_mask(d <= radius) = false;
                d_outside = d(outside_mask);
                assert(all(d_inside <= radius));
                assert(all(d_outside > radius));
            end
        end
    end

function test_cluster_neighborhood_fmri_time
    for fa_n_rep = [1, 4]
        n_time = ceil(rand() * 3 + 4);
        TR = 2;
        time_values = TR * (1:n_time);

        ds_cell = cell(n_time, fa_n_rep);
        for fa_t = 1:n_time
            ds = cosmo_synthetic_dataset('size', 'normal', 'seed', 0);
            n_features = size(ds.samples, 2);
            time_idx = ones(1, n_features) * fa_t;
            ds = cosmo_dim_insert(ds, 2, 4, {'time'}, {time_values}, {time_idx});
            ds_cell(fa_t, :) = repmat({ds}, 1, fa_n_rep);
        end

        % select subset
        ds_full = cosmo_stack(ds_cell, 2);
        n_features = size(ds_full.samples, 2);
        rp = randperm(n_features);
        n_keep = ceil(n_features / 2);
        keep_idxs = rp(1:n_keep);

        ds = cosmo_slice(ds_full, keep_idxs, 2);

        fa_ijk = [ds.fa.i; ds.fa.j; ds.fa.k];
        fa_t = ds.fa.time;

        for join_time = -1:1
            if join_time == -1
                args = {};
            else
                args = {'time', join_time == 1};
            end
            nh = cosmo_cluster_neighborhood(ds, args, 'progress', false);

            assertEqual(nh.origin.fa, ds.fa);
            assertEqual(nh.origin.a, ds.a);

            assertEqual(nh.fa.sizes, ones(1, n_keep));
            assertEqual(nh.a, ds.a);

            time_radius = abs(join_time);
            for k = 1:numel(nh.neighbors)
                idx = nh.neighbors{k};

                delta_xyz = sqrt(sum(bsxfun(@minus, fa_ijk(:, k), fa_ijk).^2, 1));
                delta_t = abs(fa_t(k) - fa_t);

                msk = delta_xyz <= 1.9 & delta_t <= time_radius;
                assertEqual(sort(idx), find(msk));
            end

        end
    end

function test_meeg_cluster_neighborhood_unknown_eeg_channels
    if cosmo_skip_test_if_no_external('fieldtrip')
        return
    end

    ds_orig = cosmo_synthetic_dataset('type', 'meeg', ...
                                      'sens', 'eeg1005', ...
                                      'size', 'normal');

    for with_unknown_channel = [false, true]
        ds = ds_orig;
        if with_unknown_channel
            nchan = max(ds.fa.chan);
            idx = ceil(rand() * nchan);
            ds.a.fdim.values{1}{idx} = 'foo';
        end

        nh = cosmo_cluster_neighborhood(ds, 'progress', false);

        if with_unknown_channel
            empty_msk = ds.fa.chan == idx;
            assert(any(empty_msk));
        else
            empty_msk = false(size(ds.fa.chan));
        end

        count = cellfun(@numel, nh.neighbors);
        assert(all(count(empty_msk) == 0));
        assert(all(count(~empty_msk) > 0));
    end

function test_cluster_neighborhood_exceptions
    ds = cosmo_synthetic_dataset();
    aet = @(varargin)assertExceptionThrown( ...
                                           @()cosmo_cluster_neighborhood(varargin{:}, ...
                                                                         'progress', false), '');

    aet(ds, 'foo');
    aet(ds, 'fmri', -1);
    aet(ds, 'fmri', true);
    aet(ds, 'fmri', [1 1]);

    ds.a.fdim.labels{2} = 'foo';
    ds.fa.foo = ds.fa.j;
    aet(ds);

    ds2 = cosmo_synthetic_dataset('type', 'meeg');
    ds2.a.fdim.labels{1} = 'freq';
    ds2.fa.freq = ones(1, 6);

    aet(ds2, 'freq');
    aet(ds2, 'freq', 2);
    aet(ds2, 'freq', NaN);
    aet(ds2, 'freq', [true true]);