function test_suite=test_cross_neighborhood()
% tests for cosmo_cross_neighborhood
%
% Output:
%   test_suite      not assigned explicitly in this function; presumably
%                   set by the initTestSuite script invoked below
%                   (TODO confirm against the unit test framework in use)
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions=localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
% test_functions is not referenced again in this function; presumably the
% initTestSuite script reads it from this workspace - TODO confirm
initTestSuite;
function test_cross_neighborhood_time_freq()
% runs the shared helper with the channel neighborhood disabled, so that
% only combinations of the 'freq' and 'time' dimensions are crossed
    with_chan_nbrhood=false;
    helper_test_cross_neighborhood(with_chan_nbrhood);
function test_cross_neighborhood_chan_time_freq()
% runs the shared helper with the channel neighborhood enabled; nothing
% is run when cosmo_skip_test_if_no_external('fieldtrip') returns true
    must_skip=cosmo_skip_test_if_no_external('fieldtrip');
    if ~must_skip
        helper_test_cross_neighborhood(true);
    end
function helper_test_cross_neighborhood(can_use_chan_nbrhood)
% Checks cosmo_cross_neighborhood on a sparse time-frequency dataset.
%
% Input:
%   can_use_chan_nbrhood    if true, a channel neighborhood is built with
%                           cosmo_meeg_chan_neighborhood and takes part in
%                           the crossings; if false, every combination
%                           that involves the channel dimension is skipped
%
% For each tested combination of the 'chan', 'freq' and 'time'
% dimensions, up to five randomly chosen positions of the crossed
% neighborhood are compared with the intersection of the matching
% per-dimension neighborhoods.

    % dense dataset, restricted to features with fa.chan below 20
    full_ds=cosmo_synthetic_dataset('type','timefreq','size','big');
    chan_msk=cosmo_match(full_ds.fa.chan,@(x)x<20);
    full_ds=cosmo_slice(full_ds,chan_msk,2);
    full_ds=cosmo_dim_prune(full_ds);

    nfeatures_full=size(full_ds.samples,2);
    nkeep=round(nfeatures_full/4);

    full_values=full_ds.a.fdim.values;
    nchan=numel(full_values{1});
    nfreq=numel(full_values{2});
    ntime=numel(full_values{3});

    % make it sparse: keep a random quarter of the features and shuffle
    % the values of the first feature dimension; retry until the number
    % of unique chan, freq and time indices matches the dense dataset
    covers_all=false;
    while ~covers_all
        perm=randperm(nfeatures_full);
        ds=cosmo_slice(full_ds,perm(1:nkeep),2);
        ds=cosmo_dim_prune(ds);

        first_dim_values=ds.a.fdim.values{1};
        shuffle=randperm(numel(first_dim_values));
        ds.a.fdim.values{1}=first_dim_values(shuffle);

        covers_all=numel(unique(ds.fa.chan))==nchan && ...
                   numel(unique(ds.fa.freq))==nfreq && ...
                   numel(unique(ds.fa.time))==ntime;
    end

    nfeatures=size(ds.samples,2);

    % per-dimension neighborhoods
    freq_nbrhood=cosmo_interval_neighborhood(ds,'freq','radius',2);
    time_nbrhood=cosmo_interval_neighborhood(ds,'time','radius',1);
    if can_use_chan_nbrhood
        chan_nbrhood=cosmo_meeg_chan_neighborhood(ds,'count',5,...
                                    'chantype','all','label','dataset');
    else
        % placeholder; combinations using the channel dimension are
        % skipped below, so this value is never selected
        chan_nbrhood='dummy';
    end

    dim_labels={'chan';'freq';'time'};
    all_nbrhoods={chan_nbrhood, freq_nbrhood, time_nbrhood};
    ndim=numel(all_nbrhoods);
    max_ntest=5; % maximum number of positions to test

    for combo=7:-1:1
        % decode which of chan / freq / time take part in this crossing
        use_dim_msk=[combo<=4; ...
                     mod(combo,2)==1; ...
                     mod(ceil(combo/2),2)==1];

        if use_dim_msk(1) && ~can_use_chan_nbrhood
            % no support for channel neighborhood, skip
            continue;
        end

        used_dims=find(use_dim_msk);
        used_labels=dim_labels(use_dim_msk);

        crossed=cosmo_cross_neighborhood(ds,all_nbrhoods(use_dim_msk),...
                                            'progress',false);

        assertEqual(crossed.a.fdim.labels,used_labels);
        assertEqual(fieldnames(crossed.fa),used_labels);

        ncrossed=numel(crossed.neighbors);
        order=randperm(ncrossed);

        for pos=order(1:min(ncrossed,max_ntest))
            expected_ids=crossed.neighbors{pos};

            % first check: look up the center of each per-dimension
            % neighborhood through its feature attribute
            ds_fa=cosmo_slice(ds.fa,expected_ids,2,'struct');
            center_fa=cosmo_slice(crossed.fa,pos,2,'struct');

            joint_msk=true(1,nfeatures);
            for dim=1:ndim
                label=dim_labels{dim};

                if ~use_dim_msk(dim)
                    % unused dimension: must be absent from the crossed
                    % neighborhood and does not restrict the features
                    assert(~isfield(crossed.fa,label))
                    dim_msk=true(1,nfeatures);
                else
                    single=all_nbrhoods{dim};
                    center=find(single.fa.(label)==center_fa.(label));
                    assert(numel(center)==1);

                    dim_msk=false(1,nfeatures);
                    dim_msk(single.neighbors{center})=true;
                end

                joint_msk=joint_msk & dim_msk;

                % attribute values of the crossed neighbors must all
                % occur among those allowed by this dimension
                allowed=cosmo_slice(ds.fa.(label),dim_msk,2);
                assert(isempty(setdiff(ds_fa.(label),allowed)));
            end

            assertEqual(expected_ids,find(joint_msk));

            % second check: agreement between the crossed neighborhood
            % and the individual neighborhoods, this time looking up the
            % center through the dimension values
            joint_msk=true(1,nfeatures);
            for dim_pos=1:numel(used_dims)
                dim=used_dims(dim_pos);
                label=dim_labels{dim};

                crossed_fa=crossed.fa.(label);
                center_value=crossed.a.fdim.values{dim_pos}(...
                                                    crossed_fa(pos));

                single=all_nbrhoods{dim};
                center_msk=cosmo_match(single.a.fdim.values{1},...
                                                    center_value);

                dim_msk=false(size(joint_msk));
                dim_msk(single.neighbors{center_msk})=true;

                joint_msk=joint_msk & dim_msk;
            end

            assertEqual(expected_ids,find(joint_msk));
        end
    end
function test_cross_neighborhood_transpose
% the crossed neighborhood must be the same whether or not
% .a.fdim.labels, .a.fdim.values, and the first and second element of
% .a.fdim.values are transposed
    opt=struct('progress',false);

    ds=cosmo_synthetic_dataset('type','timefreq','size','normal');
    ds=cosmo_dim_remove(ds,'chan');

    % reference result from the unmodified dataset
    ref_time=cosmo_interval_neighborhood(ds,'time','radius',1);
    ref_freq=cosmo_interval_neighborhood(ds,'freq','radius',0);
    ref=cosmo_cross_neighborhood(ds,{ref_freq,ref_time},opt);

    % each row selects which of the four transposes are applied
    combos=cosmo_cartprod(repmat({[false,true]},4,1));

    for row=1:size(combos,1)
        [flip_labels,flip_values,flip_elem1,flip_elem2]=combos{row,:};

        ds_t=ds;

        if flip_labels
            ds_t.a.fdim.labels=ds_t.a.fdim.labels';
        end

        if flip_values
            ds_t.a.fdim.values=ds_t.a.fdim.values';
        end

        if flip_elem1
            ds_t.a.fdim.values{1}=ds_t.a.fdim.values{1}';
        end

        if flip_elem2
            ds_t.a.fdim.values{2}=ds_t.a.fdim.values{2}';
        end

        t_time=cosmo_interval_neighborhood(ds_t,'time','radius',1);
        t_freq=cosmo_interval_neighborhood(ds_t,'freq','radius',0);
        crossed_t=cosmo_cross_neighborhood(ds_t,{t_freq,t_time},opt);

        assertEqual(crossed_t.a,ref.a);
        assertEqual(crossed_t.fa,ref.fa);
        assertEqual(crossed_t.neighbors,ref.neighbors);
    end
function test_cross_neighborhood_exceptions()
% inputs that cosmo_cross_neighborhood must reject with an exception,
% plus one input (a single neighborhood in a cell) that must be returned
% unchanged
    ds=cosmo_synthetic_dataset('type','meeg','size','big');
    valid_nh=cosmo_interval_neighborhood(ds,'time','radius',1);

    % calling cosmo_cross_neighborhood with these arguments must throw
    must_throw=@(varargin)assertExceptionThrown(@()...
                    cosmo_cross_neighborhood(varargin{:}),'');

    must_throw(ds,{});
    must_throw(ds,struct());
    must_throw(ds,struct(),valid_nh);
    must_throw(ds,valid_nh,valid_nh);
    must_throw(ds,ds);

    % should be fine
    crossed=cosmo_cross_neighborhood(ds,{valid_nh});
    assertEqual(crossed,valid_nh);

    % no values too big
    too_big_nh=valid_nh;
    too_big_nh.neighbors{1}=1e9;
    must_throw(ds,{too_big_nh});

    % non-integers not supported
    fractional_nh=valid_nh;
    fractional_nh.neighbors{1}=1.5;
    must_throw(ds,{fractional_nh});

    % illegal labels
    bad_label_nh=valid_nh;
    bad_label_nh.a.fdim.labels{1}='foo';
    bad_label_nh.fa.foo=valid_nh.fa.time;
    must_throw(ds,{bad_label_nh});

    % duplicate labels
    must_throw(ds,{valid_nh,valid_nh});
function test_cross_neighborhood_unsorted_neighbors
% crossing a single neighborhood whose neighbor indices were randomly
% permuted must give back the original, unpermuted neighborhood
    ds=cosmo_synthetic_dataset();
    reference_nh=cosmo_interval_neighborhood(ds,'i','radius',1);

    % randomly reorder the indices within each neighbor list
    shuffle=@(x)x(randperm(numel(x)));
    shuffled_nh=reference_nh;
    shuffled_nh.neighbors=cellfun(shuffle,...
                                  reference_nh.neighbors,...
                                  'UniformOutput',false);

    crossed=cosmo_cross_neighborhood(ds,{shuffled_nh});
    assertEqual(reference_nh,crossed);
function test_cross_neighborhood_progress()
% with default options, the text printed by cosmo_cross_neighborhood must
% contain a full progress bar and the 'crossing neighborhoods' message
    if cosmo_skip_test_if_no_external('!evalc')
        return;
    end

    ds=cosmo_synthetic_dataset();
    nh_i=cosmo_interval_neighborhood(ds,'i','radius',0);
    nh_j=cosmo_interval_neighborhood(ds,'j','radius',0);

    % the variable name 'f' is referenced by the string given to evalc
    % and must therefore not be changed
    f=@()cosmo_cross_neighborhood(ds,{nh_i,nh_j});
    printed=evalc('f();');

    expected_parts={'[####################]','crossing neighborhoods'};
    for k=1:numel(expected_parts)
        assert(~isempty(strfind(printed,expected_parts{k})));
    end
function test_warning_weird_dimension_order
% for every selection and every ordering of the chan, freq and time
% neighborhoods, a warning must be recorded if and only if the selected
% neighborhoods are not in increasing (chan, freq, time) order
    if cosmo_skip_test_if_no_external('fieldtrip')
        return;
    end

    ds=cosmo_synthetic_dataset('type','timefreq','size','big');
    ds=cosmo_slice(ds,ds.fa.chan<20,2);
    ds=cosmo_dim_prune(ds);

    nh_time=cosmo_interval_neighborhood(ds,'time','radius',0);
    nh_freq=cosmo_interval_neighborhood(ds,'freq','radius',0);
    nh_chan=cosmo_meeg_chan_neighborhood(ds,'count',4,...
                                            'chantype','meg_planar',...
                                            'label','dataset');

    nh_cell={nh_chan,nh_freq,nh_time};

    for nsel=1:3
        subsets=nchoosek(1:3,nsel);
        orderings=perms(1:nsel);

        for s=1:size(subsets,1)
            for p=1:size(orderings,1)
                dim_idx=subsets(s,orderings(p,:));
                selected_nh=nh_cell(dim_idx);
                expect_warning=~issorted(dim_idx);

                crosser=@()cosmo_cross_neighborhood(ds,selected_nh,...
                                                    'progress',false);
                assert_warning_being_shown(crosser,expect_warning);
            end
        end
    end
function assert_warning_being_shown(func, flag)
% Calls func and checks whether cosmo_warning recorded a shown warning.
%
% Inputs:
%   func    function handle that is called without arguments
%   flag    if true, .shown_warnings of the cosmo_warning state must be a
%           non-empty cell with strings after calling func; if false, it
%           must still be an empty cell
%
% The warning state found on entry is put back through an onCleanup
% object when this function exits.
    original_state=cosmo_warning();
    restore_state=onCleanup(@()cosmo_warning(original_state));

    cosmo_warning('reset');
    cosmo_warning('off');

    % initially must be empty
    before=cosmo_warning();
    assertEqual(before.shown_warnings,{});

    func();

    after=cosmo_warning();
    shown=after.shown_warnings;

    if ~flag
        % still without warning
        assertEqual(shown,{});
    else
        % must have warning added
        assertTrue(iscellstr(shown));
        assertTrue(numel(shown)>0);
    end