test cross neighborhood

function test_suite=test_cross_neighborhood()
% tests for cosmo_cross_neighborhood
%
% Entry point of this test file. It does not run any checks itself; the
% construction of the suite is delegated to the initTestSuite script that
% is invoked on the last line.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % NOTE(review): test_functions is never referenced in this function;
    % presumably initTestSuite reads it (and assigns test_suite) in this
    % workspace, so neither name should be changed - confirm against the
    % test framework before touching them
    initTestSuite;

function test_cross_neighborhood_time_freq()
% run the shared checks with the channel neighborhood disabled, so that
% only combinations of the frequency and time neighborhoods are crossed
    with_chan_nbrhood=false;
    helper_test_cross_neighborhood(with_chan_nbrhood);

function test_cross_neighborhood_chan_time_freq()
% run the shared checks with the channel neighborhood enabled; nothing is
% tested when cosmo_skip_test_if_no_external('fieldtrip') asks to skip
    must_skip=cosmo_skip_test_if_no_external('fieldtrip');
    if must_skip
        return;
    end

    with_chan_nbrhood=true;
    helper_test_cross_neighborhood(with_chan_nbrhood);

function helper_test_cross_neighborhood(can_use_chan_nbrhood)
% shared checks for crossing channel, frequency and time neighborhoods
%
% Input:
%   can_use_chan_nbrhood   if true, a channel neighborhood is computed
%                          with cosmo_meeg_chan_neighborhood and all
%                          combinations of chan/freq/time are tested;
%                          if false, a placeholder is used instead and
%                          every combination involving 'chan' is skipped
%
% For each tested combination the crossed neighborhood is compared, at a
% few random positions, with the intersection of the corresponding
% single-dimension neighborhoods.

    % start from a big time-frequency dataset, restricted to features
    % with a channel index below 20
    ds_full=cosmo_synthetic_dataset('type','timefreq','size','big');
    msk=cosmo_match(ds_full.fa.chan,@(x)x<20);
    ds_full=cosmo_slice(ds_full,msk,2);
    ds_full=cosmo_dim_prune(ds_full);

    % make it sparse
    nfeatures=size(ds_full.samples,2);
    nkeep=round(nfeatures/4);
    fdim_values=ds_full.a.fdim.values;
    nchan=numel(fdim_values{1});
    nfreq=numel(fdim_values{2});
    ntime=numel(fdim_values{3});

    % keep drawing a random quarter of the features until, after pruning,
    % every chan, freq and time value of the full dataset is still in use
    while true
        rp=randperm(nfeatures);
        ds=cosmo_slice(ds_full,rp(1:nkeep),2);
        ds=cosmo_dim_prune(ds);
        % shuffle the values of the first feature dimension, so that they
        % are not in their original order
        n=numel(ds.a.fdim.values{1});
        ds.a.fdim.values{1}=ds.a.fdim.values{1}(randperm(n));

        if numel(unique(ds.fa.chan))==nchan && ...
                    numel(unique(ds.fa.freq))==nfreq && ...
                    numel(unique(ds.fa.time))==ntime
            break
        end
    end

    % from here on nfeatures refers to the sparse dataset
    nfeatures=size(ds.samples,2);

    % define neighborhoods
    freq_nbrhood=cosmo_interval_neighborhood(ds,'freq','radius',2);
    time_nbrhood=cosmo_interval_neighborhood(ds,'time','radius',1);
    if can_use_chan_nbrhood
        chan_nbrhood=cosmo_meeg_chan_neighborhood(ds,'count',5,...
                                'chantype','all','label','dataset');
    else
        % placeholder to keep the cell below at three elements; it is
        % never used, because combinations with 'chan' are skipped
        chan_nbrhood='dummy';
    end

    % same order as dim_labels
    all_nbrhoods={chan_nbrhood, freq_nbrhood, time_nbrhood};
    ndim=numel(all_nbrhoods);
    dim_labels={'chan';'freq';'time'};

    ntest=5;  % number of positions to test

    % i=7,...,1 goes through all seven non-empty combinations of
    % chan, freq and time; chan is included for i=1,...,4
    for i=7:-1:1
        use_chan=i<=4;
        use_freq=mod(i,2)==1;
        use_time=mod(ceil(i/2),2)==1;

        if ~can_use_chan_nbrhood && use_chan
            % no support for channel neighborhood, skip
            continue;
        end

        use_dim_msk=[use_chan;use_freq;use_time];
        nbrhood=cosmo_cross_neighborhood(ds,all_nbrhoods(use_dim_msk),...
                                                'progress',false);
        % the result must have exactly the crossed dimensions, both as
        % dimension labels and as feature attribute names
        assertEqual(nbrhood.a.fdim.labels,dim_labels(use_dim_msk));
        assertEqual(fieldnames(nbrhood.fa),dim_labels(use_dim_msk));

        % visit at most ntest positions, in random order
        n=numel(nbrhood.neighbors);
        rp=randperm(n);

        for iter=1:min(n,ntest)
            pos=rp(iter);
            % verify neighborhoods in ds

            % ds_fa:  feature attributes in ds of the neighbors at pos
            % nbr_fa: feature attributes of the crossed nbrhood at pos
            ds_fa=cosmo_slice(ds.fa,nbrhood.neighbors{pos},2,'struct');
            nbr_fa=cosmo_slice(nbrhood.fa,pos,2,'struct');

            % features that are a neighbor in every crossed dimension
            nbr_msk=true(1,nfeatures);

            for dim=1:ndim
                dim_label=dim_labels{dim};
                if use_dim_msk(dim)
                    dim_nbrhood=all_nbrhoods{dim};

                    % the single-dimension neighborhood must have exactly
                    % one position with the same feature attribute value
                    j=find(dim_nbrhood.fa.(dim_label)==nbr_fa.(dim_label));
                    assert(numel(j)==1);

                    m=false(1,nfeatures);
                    m(dim_nbrhood.neighbors{j})=true;
                else
                    % dimension was not crossed: it must be absent from
                    % the result and does not restrict the features
                    assert(~isfield(nbrhood.fa,dim_label))
                    m=true(1,nfeatures);
                end

                nbr_msk=nbr_msk & m;

                % all values of this dimension among the neighbors must
                % occur among the features selected by m
                fa=cosmo_slice(ds.fa.(dim_label),m,2);
                assert(isempty(setdiff(ds_fa.(dim_label),fa)));
            end

            assertEqual(nbrhood.neighbors{pos},find(nbr_msk));

            % test agreement between the crossed nbrhood and the
            % individual neighborhoods
            % (this time positions are matched through the dimension
            % values in .a.fdim.values; note that nbr_fa and nbr_msk are
            % reused below with a different meaning than above)
            dim_nbr_msk=true(1,nfeatures);
            dim_pos=0;  % index of the dimension within the crossed result
            for dim=1:ndim
                if ~use_dim_msk(dim)
                    continue;
                end
                dim_pos=dim_pos+1;
                dim_label=dim_labels{dim};
                % value of this dimension at pos in the crossed nbrhood
                nbr_fa=nbrhood.fa.(dim_label);
                nbr_values=nbrhood.a.fdim.values{dim_pos}(nbr_fa(pos));

                % locate that value in the single-dimension neighborhood
                dim_nbrhood=all_nbrhoods{dim};
                dim_nbr_values=dim_nbrhood.a.fdim.values{1};
                nbr_msk=cosmo_match(dim_nbr_values,nbr_values);

                % NOTE(review): assumes that exactly one element of
                % nbr_msk is true, so that the brace indexing below
                % yields a single vector of feature indices
                m=false(size(dim_nbr_msk));
                m(dim_nbrhood.neighbors{nbr_msk})=true;
                dim_nbr_msk=dim_nbr_msk & m;
            end
            assertEqual(nbrhood.neighbors{pos},find(dim_nbr_msk));
        end
    end

function test_cross_neighborhood_transpose
% the crossed freq x time neighborhood must be the same regardless of
% whether .a.fdim.labels, .a.fdim.values, or the first or second element
% of .a.fdim.values is transposed
    opt=struct('progress',false);

    ds=cosmo_synthetic_dataset('type','timefreq','size','normal');
    ds=cosmo_dim_remove(ds,'chan');

    % reference result, computed from the unmodified dataset
    ref_time=cosmo_interval_neighborhood(ds,'time','radius',1);
    ref_freq=cosmo_interval_neighborhood(ds,'freq','radius',0);
    ref_nh=cosmo_cross_neighborhood(ds,{ref_freq,ref_time},opt);

    % each row holds one on/off combination of the four transpositions
    combis=cosmo_cartprod(repmat({[false,true]},4,1));
    ncombi=size(combis,1);

    for row=1:ncombi
        [flip_labels,flip_values,flip_first,flip_second]=combis{row,:};

        % apply the selected transpositions to a copy of the dimension
        % information
        fdim=ds.a.fdim;
        if flip_labels
            fdim.labels=fdim.labels';
        end
        if flip_values
            fdim.values=fdim.values';
        end
        if flip_first
            fdim.values{1}=fdim.values{1}';
        end
        if flip_second
            fdim.values{2}=fdim.values{2}';
        end

        ds_alt=ds;
        ds_alt.a.fdim=fdim;

        alt_time=cosmo_interval_neighborhood(ds_alt,'time','radius',1);
        alt_freq=cosmo_interval_neighborhood(ds_alt,'freq','radius',0);
        alt_nh=cosmo_cross_neighborhood(ds_alt,{alt_freq,alt_time},opt);

        assertEqual(alt_nh.a,ref_nh.a);
        assertEqual(alt_nh.fa,ref_nh.fa);
        assertEqual(alt_nh.neighbors,ref_nh.neighbors);
    end





function test_cross_neighborhood_exceptions()
% malformed input to cosmo_cross_neighborhood must raise an exception,
% whereas a cell with a single valid neighborhood gives that neighborhood
    ds=cosmo_synthetic_dataset('type','meeg','size','big');
    valid_nh=cosmo_interval_neighborhood(ds,'time','radius',1);

    % passes only if cosmo_cross_neighborhood, called with the arguments
    % given, throws an exception
    assert_throws=@(varargin)assertExceptionThrown(@()...
                        cosmo_cross_neighborhood(varargin{:}),'');

    % empty cell, or neighborhood argument not wrapped in a cell
    assert_throws(ds,{});
    assert_throws(ds,struct);
    assert_throws(ds,struct,valid_nh);
    assert_throws(ds,valid_nh,valid_nh);
    assert_throws(ds,ds);

    % a single valid neighborhood is fine
    assertEqual(cosmo_cross_neighborhood(ds,{valid_nh}),valid_nh);

    % neighbor index that is too big
    too_big_nh=valid_nh;
    too_big_nh.neighbors{1}=1e9;
    assert_throws(ds,{too_big_nh});

    % neighbor index that is not an integer
    fractional_nh=valid_nh;
    fractional_nh.neighbors{1}=1.5;
    assert_throws(ds,{fractional_nh});

    % dimension label that is illegal
    bad_label_nh=valid_nh;
    bad_label_nh.a.fdim.labels{1}='foo';
    bad_label_nh.fa.foo=valid_nh.fa.time;
    assert_throws(ds,{bad_label_nh});

    % same dimension label used twice
    assert_throws(ds,{valid_nh,valid_nh});


function test_cross_neighborhood_unsorted_neighbors
% randomly reordering the indices within each neighbor list of the input
% must not change the output; it has to equal the original neighborhood
    ds=cosmo_synthetic_dataset();
    nh=cosmo_interval_neighborhood(ds,'i','radius',1);

    shuffled_nh=nh;
    for k=1:numel(shuffled_nh.neighbors)
        ids=shuffled_nh.neighbors{k};
        shuffled_nh.neighbors{k}=ids(randperm(numel(ids)));
    end

    assertEqual(nh,cosmo_cross_neighborhood(ds,{shuffled_nh}));


function test_cross_neighborhood_progress()
% without a 'progress' option, crossing neighborhoods must print a
% completed progress bar and a message naming the operation
    if cosmo_skip_test_if_no_external('!evalc')
        return;
    end

    ds=cosmo_synthetic_dataset();
    nh_i=cosmo_interval_neighborhood(ds,'i','radius',0);
    nh_j=cosmo_interval_neighborhood(ds,'j','radius',0);

    % the handle must be named f, as the string passed to evalc refers
    % to it by that name
    f=@()cosmo_cross_neighborhood(ds,{nh_i,nh_j});
    printed=evalc('f();');

    expected_parts={'[####################]','crossing neighborhoods'};
    for k=1:numel(expected_parts)
        assert(~isempty(strfind(printed,expected_parts{k})));
    end


function test_warning_weird_dimension_order
% crossing neighborhoods in an order other than chan, freq, time must
% register a warning; crossing them in that order must not
    if cosmo_skip_test_if_no_external('fieldtrip')
        return;
    end

    ds=cosmo_synthetic_dataset('type','timefreq','size','big');
    ds=cosmo_dim_prune(cosmo_slice(ds,ds.fa.chan<20,2));

    time_nh=cosmo_interval_neighborhood(ds,'time','radius',0);
    freq_nh=cosmo_interval_neighborhood(ds,'freq','radius',0);
    chan_nh=cosmo_meeg_chan_neighborhood(ds,'count',4,...
                                    'chantype','meg_planar',...
                                    'label','dataset');

    % stored in the order for which no warning is expected
    nbrhoods={chan_nh,freq_nh,time_nh};
    ntotal=numel(nbrhoods);

    % try every subset of the neighborhoods in every possible order
    for nsel=1:ntotal
        subsets=nchoosek(1:ntotal,nsel);
        orderings=perms(1:nsel);

        for s=1:size(subsets,1)
            for p=1:size(orderings,1)
                picked=subsets(s,orderings(p,:));

                % a warning is expected only if the indices are not in
                % ascending order
                expect_warning=~issorted(picked);

                crosser=@()cosmo_cross_neighborhood(ds,...
                                    nbrhoods(picked),...
                                    'progress',false);
                assert_warning_being_shown(crosser,expect_warning);
            end
        end
    end



function assert_warning_being_shown(func, flag)
% call func() and verify whether cosmo_warning registered a warning
%
% Inputs:
%   func    function handle that is called without arguments
%   flag    if true, at least one warning must be listed in
%           .shown_warnings after the call; otherwise none must be
%
% The cosmo_warning state found on entry is put back when this function
% exits.
    saved_state=cosmo_warning();
    restore_on_exit=onCleanup(@()cosmo_warning(saved_state));

    cosmo_warning('reset');
    cosmo_warning('off');

    % nothing must be listed before func is called
    before=cosmo_warning();
    assertEqual(before.shown_warnings,{});

    func();

    after=cosmo_warning();
    listed=after.shown_warnings;
    if flag
        % at least one warning was added
        assertTrue(iscellstr(listed));
        assertTrue(~isempty(listed));
    else
        % list of warnings is still empty
        assertEqual(listed,{});
    end