test surficial neighborhood

function test_suite = test_surficial_neighborhood
% tests for cosmo_surficial_neighborhood
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_surficial_neighborhood_surface_dijkstra
    if cosmo_skip_test_if_no_external('surfing')
        return;
    end
    warning_state=warning();
    warning_resetter=onCleanup(@()warning(warning_state));
    warning('off');

    opt=struct();
    opt.progress=false;

    ds=cosmo_synthetic_dataset('type','surface');

    [vertices,faces]=get_synthetic_surface();

    % dijkstra neighborhood fixed number of voxels
    args={{vertices,faces},'count',4,'metric','dijkstra',opt};
    nh1=cosmo_surficial_neighborhood(ds,args{:});
    assertFalse(isfield(nh1.a,'vol'));
    assert_equal_cell(nh1.neighbors,{ [ 1 2 4 3 ]
                                        [ 2 1 3 5 ]
                                        [ 3 2 6 5 ]
                                        [ 4 1 5 2 ]
                                        [ 5 2 4 6 ]
                                        [ 6 3 5 2 ] });

    assertEqual(nh1.fa.radius,[2 1 sqrt(2) sqrt(2) 1 2]);
    assertEqual(nh1.fa.node_indices,1:6);
    check_area(vertices,faces,nh1);

    args={{vertices,faces},'radius',2.5,'metric','dijkstra',opt};
    nh2=cosmo_surficial_neighborhood(ds,args{:});

    assert_equal_cell(nh2.neighbors,{[1 2 3 4 5];
                                     [1 2 3 4 5 6];
                                     [1 2 3 4 5 6];
                                     [1 2 3 4 5 6];
                                     [1 2 3 4 5 6];
                                     [2 3 4 5 6]; });
    assertEqual(nh2.fa.radius,[2,2,1+sqrt(2),1+sqrt(2),2,2]);
    assertEqual(nh2.fa.node_indices,1:6);
    check_partial_neighborhood(ds,nh2,args);
    check_area(vertices,faces,nh2);

    args{1}{1}([2,6],:)=NaN;

    nh3=cosmo_surficial_neighborhood(ds,args{:});
    assert_equal_cell(nh3.neighbors,{[1 4 5];
                                     [];
                                     [3 4 5];
                                     [1 3 4 5];
                                     [1 3 4 5];
                                     []; });
    assertEqual(nh3.fa.radius,[2,NaN,1+sqrt(2),1+sqrt(2),2,NaN]);
    check_partial_neighborhood(ds,nh3,args);
    check_area(args{1}{:},nh3);

    args{2}='count';
    args{3}=3;

    nh4=cosmo_surficial_neighborhood(ds,args{:});
    assert_equal_cell(nh4.neighbors,{[1 4 5];
                                     [];
                                     [3 5 4];
                                     [4 1 5];
                                     [5 4 3];
                                     []; });
    assertEqual(nh4.fa.radius,[2,NaN,1+sqrt(2),1,sqrt(2),NaN]);
    check_area(args{1}{:},nh4);

    args{1}{1}=vertices;
    args{1}{1}([2,5],:)=NaN; % split in two surfaces
    args{3}=2;
    nh5=cosmo_surficial_neighborhood(ds,args{:});
    assert_equal_cell(nh5.neighbors,{[1 4];
                                     [];
                                     [3 6];
                                     [4 1];
                                     [];
                                     [6 3]; });
    assertEqual(nh5.fa.radius,[1,NaN,1,1,NaN,1]);
    check_area(args{1}{:},nh5);

    % throw error when too many nodes asked for
    aet=@(varargin)assertExceptionThrown(@()...
                    cosmo_surficial_neighborhood(varargin{:}),'');
    args{2}='count';
    args{3}=3;

    aet(ds,args{:})



function test_surficial_neighborhood_surface_direct
    if cosmo_skip_test_if_no_external('surfing')
        return;
    end
    warning_state=warning();
    warning_resetter=onCleanup(@()warning(warning_state));
    warning('off');

    opt=struct();
    opt.progress=false;

    ds=cosmo_synthetic_dataset('type','surface');

    [vertices,faces]=get_synthetic_surface();

    % direct neighborhood
    args={{vertices,faces},'direct',true,opt};
    nh3=cosmo_surficial_neighborhood(ds,args{:});
    assert_equal_cell(nh3.neighbors,{ [ 1 2 4 ]
                                        [ 2 3 1 4 5 ]
                                        [ 3 2 5 6 ]
                                        [ 4 1 2 5 ]
                                        [ 5 3 6 2 4 ]
                                        [ 6 3 5 ] });
    assertElementsAlmostEqual(nh3.fa.radius,sqrt([1 2 2 2 2 1]));
    check_partial_neighborhood(ds,nh3,args);

    args{1}{1}([2 5],:)=NaN;
    nh4=cosmo_surficial_neighborhood(ds,args{:});
    assert_equal_cell(nh4.neighbors,{ [ 1 4 ]
                                        []
                                        [ 3 6 ]
                                        [ 1 4 ]
                                        []
                                        [6 3] });

    check_partial_neighborhood(ds,nh4,args);

function test_surficial_neighborhood_surface_geodesic
    if cosmo_skip_test_if_no_external('fast_marching') || ...
            cosmo_skip_test_if_no_external('surfing')
        return;
    end

    warning_state=warning();
    warning_resetter=onCleanup(@()warning(warning_state));
    warning('off');

    opt=struct();
    opt.progress=false;

    ds=cosmo_synthetic_dataset('type','surface');%,'size','normal');

    [vertices,faces]=get_synthetic_surface();

    args={{vertices,faces},'count',4,opt};
    nh=cosmo_surficial_neighborhood(ds,args{:});
    assert_equal_cell(nh.neighbors,{[ 1 2 4 5 ]
                                    [ 2 1 3 5 ]
                                    [ 3 2 6 5 ]
                                    [ 4 1 5 2 ]
                                    [ 5 2 4 6 ]
                                    [ 6 3 5 2 ] });
    assertEqual(nh.fa.node_indices,1:6);
    assertEqual(nh.fa.radius,[sqrt(.5)+1 1 sqrt(2) sqrt(2) 1 sqrt(.5)+1]);

    vertices2=[NaN NaN NaN; vertices;NaN NaN NaN];
    faces2=[faces+1; 1 1 8];
    args={{vertices2,faces2},'count',4,opt};
    nh2=cosmo_surficial_neighborhood(ds,args{:});

    assertEqual(nh2.neighbors,{ zeros(1,0)
                                 [ 2 3 5 6 ]
                                 [ 3 2 4 6 ]
                                 [ 4 3 6 2 ]
                                 [ 5 2 6 3 ]
                                 [ 6 3 5 4 ] });

function test_surficial_neighborhood_volume_geodesic
    if cosmo_skip_test_if_no_external('fast_marching') || ...
            cosmo_skip_test_if_no_external('surfing')
        return;
    end
    warning_state=warning();
    warning_resetter=onCleanup(@()warning(warning_state));
    warning('off');

    opt=struct();
    opt.progress=false;

    ds=cosmo_synthetic_dataset();
    vertices=[-2 0 2 -2 0 2;...
                -1 -1 -1 1 1 1;...
                -1 -1 -1 -1 -1 -1]';
    faces= [ 3 2 3 2;...
                2 1 5 4;...
                5 4 6 5 ]';

    pial=vertices;
    pial(:,3)=pial(:,3)+1;
    white=vertices;
    white(:,3)=white(:,3)-1;
    nh1=cosmo_surficial_neighborhood(ds,{vertices,[-1 1],faces},...
                                    'count',4,opt);
    nh2=cosmo_surficial_neighborhood(ds,{pial,white,faces},...
                                    'count',4,opt);
    assert_equal_cell(nh1.neighbors,{[ 1 2 4 5 ]
                                        [ 1 2 3 5 ]
                                        [ 2 3 5 6 ]
                                        [ 4 1 5 2 ]
                                        [ 5 2 4 6 ]
                                        [ 6 3 5 2 ] });
    assertEqual(nh1.fa.node_indices,1:6);
    assertEqual(nh1,nh2);


function test_surficial_neighborhood_volume_dijkstra
    if cosmo_skip_test_if_no_external('surfing')
        return;
    end
    warning_state=warning();
    warning_resetter=onCleanup(@()warning(warning_state));
    warning('off');

    opt=struct();
    opt.progress=false;

    ds=cosmo_synthetic_dataset();
    vertices=[-2 0 2 -2 0 2;
                -1 -1 -1 1 1 1
                -1 -1 -1 -1 -1 -1]';
    faces= [ 3 2 3 2
                2 1 5 4
                5 4 6 5 ]';

    pial=vertices;
    pial(:,3)=pial(:,3)+1;
    white=vertices;
    white(:,3)=white(:,3)-1;

    args3={{vertices,[-1 1],faces},'metric','dijkstra','count',4,opt};
    args4={{pial,white,faces},'metric','dijkstra','count',4,opt};
    nh3=cosmo_surficial_neighborhood(ds,args3{:});
    nh4=cosmo_surficial_neighborhood(ds,args4{:});
    assert_equal_cell(nh3.neighbors,{ [ 1 2 4 3 ]
                                        [ 2 1 3 5 ]
                                        [ 3 2 6 5 ]
                                        [ 4 1 5 2 ]
                                        [ 5 2 4 6 ]
                                        [ 6 3 5 2 ] });
    assertEqual(nh3.fa.node_indices,1:6);
    assert_equal_cell(nh4.neighbors,nh3.neighbors);
    assertFalse(isfield(nh3.a,'vol'));

    % TODO
    % check_partial_neighborhood(ds,nh3,args3);
    % check_partial_neighborhood(ds,nh3,args3);

function test_surficial_neighborhood_exceptions
    if cosmo_skip_test_if_no_external('surfing')
        return
    end
    ds=cosmo_synthetic_dataset('type','surface');%,'size','normal');
    [vertices,faces]=get_synthetic_surface();

    aet=@(varargin)assertExceptionThrown(@()...
                    cosmo_surficial_neighborhood(varargin{:},...
                                        'progress',false),'');
    aet(ds,{vertices,faces});

    % need surfaces
    aet(ds,{},'radius',2);

    % center_ids not supported for surface dataset
    aet(ds,{vertices,faces},'radius',2,'center_ids',1);

    % cannot have duplicate feature ids
    ds_double=cosmo_stack({ds,ds},2);
    ds_double.a.fdim.values{1}=[1:6,1:6];
    aet(ds_double,{vertices,faces},'radius',2);

    % outside range
    ds_double.a.fdim.values{1}=[1:5,1:5];
    aet(ds_double,{vertices,faces},'radius',2);

    % cannot have fmri and surface dataset combined
    ds2=cosmo_synthetic_dataset();
    ds2.a.fdim.values{end+1}=ds.a.fdim.values{1};
    ds2.a.fdim.labels{end+1}=ds.a.fdim.labels{1};
    ds2.fa.node_indices=ds.fa.node_indices;
    aet(ds2,{vertices,faces},'radius',2);

    % cannot have MEEG dataset
    ds_meeg=cosmo_synthetic_dataset('type','meeg');
    aet(ds_meeg,{vertices,faces},'radius',2);

    % need positive scalar radius
    aet(ds,{vertices,faces},'radius',-1);
    aet(ds,{vertices,faces},'radius',eye(2));


function check_partial_neighborhood(ds,nh,args)
    % see if when have a partial dataset, the neighborbood reflects
    % that too

    nf=size(ds.samples,2);

    rp=randperm(nf);
    keep_count=round(nf*.7);
    keep_sel=rp(1:keep_count);
    keep_all=[keep_sel keep_sel keep_sel];

    ds_sel=cosmo_slice(ds,keep_all,2);

    fdim=ds_sel.a.fdim.values{1};
    rp_fdim=randperm(numel(fdim));
    ds_sel.a.fdim.values{1}=fdim(rp_fdim);

    nh_sel=cosmo_surficial_neighborhood(ds_sel,args{:});

    assertEqual(numel(nh_sel.neighbors), numel(keep_sel));

    assertEqual(nh_sel.a.fdim.labels,nh.a.fdim.labels);
    assertEqual(nh_sel.a.fdim.values{1},nh.a.fdim.values{1}(rp_fdim));
    assertEqual(numel(nh_sel.a.fdim.values),numel(nh.a.fdim.values));

    assertEqual(ds_sel.a,nh_sel.a);

    opt=cosmo_structjoin(args(2:end));


    if isfield(opt,'radius')
        metric=opt.metric;
        metric_arg=opt.radius;
    elseif isfield(opt,'count')
        metric=opt.metric;
        metric_arg=[10 opt.count];
    elseif isfield(opt,'direct')
        metric='direct';
        if opt.direct
            metric_arg=NaN;
        else
            metric_arg=0;
        end
    else
        assert(false);
    end



    faces=args{1}{2};
    n2f=surfing_invertmapping(faces);

    nodes_ds_sel=ds_sel.a.fdim.values{1}(ds_sel.fa.node_indices);
    nodes_ds=ds.a.fdim.values{1}(ds.fa.node_indices);

    nodes_nh_sel=nh_sel.a.fdim.values{1}(nh_sel.fa.node_indices);

    vertices=args{1}{1};


    nvertices=size(vertices,1);
    nodes_kept=cosmo_match(1:nvertices,nodes_ds_sel);
    vertices(~nodes_kept,:)=NaN;

    node_mask=all(isfinite(vertices),2);

    nodes_removed=setdiff(nodes_ds(:)',nodes_ds_sel(:)');
    assertEqual(setxor(nodes_removed,nodes_ds_sel),1:nf);

    nb_sel=nh_sel.neighbors;
    for k=1:numel(nh_sel.neighbors)
        sel_center_node=nodes_nh_sel(k);
        idx=find(nodes_ds==sel_center_node);
        assert(numel(idx)==1);
        center_node=nodes_ds(idx);

        assertEqual(sel_center_node, center_node);

        switch metric
            case 'direct'
                if node_mask(sel_center_node)
                    direct_neighbors=surfing_surface_nbrs(faces',...
                                                            vertices');
                    around_nodes=direct_neighbors(sel_center_node,:);
                    msk=cosmo_match(around_nodes, ...
                                            find(isfinite(vertices(:,1))));
                    % add node itself
                    around_nodes=[sel_center_node,...
                                    around_nodes(msk & around_nodes>0)];
                else
                    around_nodes=[];
                end
            otherwise
                around_nodes=surfing_circleROI(vertices',faces',...
                            sel_center_node,metric_arg,metric,n2f);
        end

        sel_around_nodes=nodes_ds_sel(nb_sel{k});

        if isempty(sel_around_nodes)
            assertTrue(isempty(around_nodes));
        else
            assertEqual(unique(sel_around_nodes),...
                        setdiff(around_nodes, nodes_removed))
        end
    end


function test_surface_subsampling
    if cosmo_skip_test_if_no_external('surfing')
        return;
    end
    vertices=[0 -1 -2 -1  1  2  1  3  4  3;
          0 -2  0  2  2  0 -2  2  0 -2;
          0  0  0  0  0  0  0  0  0  0]';

    faces=[1 1 1 1 1 1 5 8 6  6;
           2 3 4 5 6 7 8 9 9 10;
           3 4 5 6 7 2 6 6 10 7]';

    % make custom volume
    ds=cosmo_synthetic_dataset('size','small');
    cp=cosmo_cartprod({1:7,1:3})';
    ds.fa.i=cp(1,:);
    ds.fa.j=cp(2,:);
    ds.fa.k=ds.fa.i*0+1;

    ds.a.fdim.values={1:7;1:3;1};
    ds.samples=zeros(numel(ds.sa.targets),numel(ds.fa.i));
    ds.a.vol.dim=cellfun(@numel,ds.a.fdim.values)';
    ds.a.vol.mat(2,4)=-4;
    ds.a.vol.mat(3,4)=-1;
    ds.a.vol.mat(1,1)=1;
    ds.a.vol.mat(2,2)=2;
    ds.a.vol.mat(3,3)=1;

    surfs={vertices,faces,[-4 5]};

    opt=struct();
    opt.progress=false;
    opt.radius=3;
    opt.metric='euclidean';
    nh=cosmo_surficial_neighborhood(ds,surfs,opt);

    assertEqual(nh.neighbors,{ [ 2 4 8 10 12 16 18 ]
                             [ 2 4 8 10 ]
                             [ 2 8 10 16 ]
                             [ 8 10 16 18 ]
                             [ 10 12 16 18 20 ]
                             [ 4 6 10 12 14 18 20 ]
                             [ 2 4 6 10 12 ]
                             [ 12 14 18 20 ]
                             [ 6 12 14 20 ]
                             [ 4 6 12 14 ] });
    assertEqual(nh.origin.fa,ds.fa);
    assertEqual(nh.origin.a,ds.a);

    % test subsampling
    subsample=2;
    surfs={vertices,faces,[-4 5],subsample};
    nh2=cosmo_surficial_neighborhood(ds,surfs,opt);
    assertEqual(nh2.neighbors,{ [ 2 4 8 10 ]
                                 [ 2 8 10 16 ]
                                 [ 8 10 16 18 ]
                                 [ 10 12 16 18 20 ]
                                 [ 2 4 6 10 12 ]
                                 [ 12 14 18 20 ]
                                 [ 6 12 14 20 ]
                                 [ 4 6 12 14 ] });

    assertEqual(nh2.origin.fa,ds.fa);
    assertEqual(nh2.origin.a,ds.a);

    % subsampling with pial surface
    pial=bsxfun(@plus,vertices,[0 0 1]);
    white=bsxfun(@plus,vertices,[0 0 -1]);
    [vo,fo]=surfing_subsample_surface(vertices,faces,2,.2,0);
    surfs={pial,white,faces,vo,fo};
    nh3=cosmo_surficial_neighborhood(ds,surfs,opt);
    assertEqual(nh2,nh3);

    % check center ids options
    slice_ids=[5 3 2];
    nh4=cosmo_surficial_neighborhood(ds,surfs,opt,'center_ids',slice_ids);
    nh4_sl=struct();
    nh4_sl.neighbors=nh3.neighbors(slice_ids);
    nh4_sl.fa=cosmo_slice(nh3.fa,slice_ids,2,'struct');
    nh4_sl.a=nh3.a;

    assertEqual(nh4.a.fdim.values{1}(nh4.fa.node_indices),...
                    nh4_sl.a.fdim.values{1}(nh4_sl.fa.node_indices));
    assertEqual(nh4.neighbors,nh4_sl.neighbors);


    % try with file names
    fn_pial=cosmo_make_temp_filename('pial','.asc');
    fn_white=cosmo_make_temp_filename('white','.asc');
    fn_tiny=cosmo_make_temp_filename('tiny','.asc');

    cleaner1=onCleanup(@()delete(fn_pial));
    cleaner2=onCleanup(@()delete(fn_white));
    cleaner3=onCleanup(@()delete(fn_tiny));

    surfing_write(fn_pial, pial, faces);
    surfing_write(fn_white, white, faces);
    surfing_write(fn_tiny, vo, fo);

    surfs={fn_pial,fn_white,fn_tiny};
    nh5=cosmo_surficial_neighborhood(ds,surfs,opt);
    assertEqual(nh2,nh5);

    % should work with alternative voldef
    ds_bad_vol=ds;
    ds_bad_vol.a.vol.mat(:)=NaN;
    ds_bad_vol.a.vol.dim(:)=NaN;
    nh6=cosmo_surficial_neighborhood(ds_bad_vol,surfs,opt,...
                                            'vol_def',ds.a.vol);
    nh6.origin.a.vol=ds.a.vol;
    assertEqual(nh5,nh6);

    % check exceptions
    aet=@(varargin)assertExceptionThrown(@()...
                    cosmo_surficial_neighborhood(varargin{:}),'');

    surfs={fn_pial,fn_pial,fn_tiny};
    aet(ds,surfs,opt);

    white_bad=white;
    white_bad=white_bad(2:end,:);
    aet(ds,{pial,white_bad,faces},opt);

    % missing faces for output surface
    aet(ds,{fn_pial,fn_white,vo},opt);


    % face mismatch
    faces_bad=faces;
    faces_bad=faces_bad(end:-1:1,:);

    surfing_write(fn_white, white, faces_bad);
    aet(ds,{fn_pial,fn_white},opt);



    % too many surf arguments
    aet(ds,{fn_pial,fn_white,fn_tiny,fn_tiny},opt);
    aet(ds,{pial,white,faces,pial,white,white},opt);
    aet(ds,{pial,white,faces,fn_pial,white},opt);

    % surfs are not a cell
    aet(ds,struct,opt);
    aet(ds,{pial,white,{}});


function check_area(vertices,faces,nh)
    assert(isfield(nh.fa,'area'));
    area=surfing_surfacearea(vertices,faces);

    node_idxs=nh.fa.node_indices;
    n_nodes=numel(node_idxs);
    for k=1:n_nodes
        nbr_idxs=node_idxs(nh.neighbors{k});
        expected_area=sum(area(nbr_idxs));
        if isnan(expected_area)
            assertEqual(expected_area,nh.fa.area(k))
        else
            assertElementsAlmostEqual(expected_area,nh.fa.area(k));
        end
    end




function [vertices,faces]=get_synthetic_surface()
    % return the following surface (face indices in [brackets])
    %
    %  1-----2-----3
    %  |    /|    /|
    %  |[2]/ |[1]/ |
    %  |  /  |  /  |
    %  | /[4]| /[3]|
    %  |/    |/    |
    %  4-----5-----6

    vertices=[0 0 0 1 1 1;
                1 2 3 1 2 3;
                0 0 0 0 0 0]';
    faces= [ 3 2 3 2
                2 1 5 4
                5 4 6 5 ]';


function assert_equal_cell(x,y)
    % small helper
    assertEqual(size(x),size(y))
    for k=1:numel(x)
        xk=x{k};
        yk=y{k};
        if isempty(xk)
            assertTrue(isempty(yk));
        else
            assertEqual(sort(xk),sort(yk));
        end
    end