function test_suite=test_dim_generalization_measure()
% tests for cosmo_dim_generalization_measure
%
% Entry point of this test file. It tries to collect handles to the local
% functions in this file and then invokes initTestSuite.
% NOTE(review): the output 'test_suite' is not assigned in this function
% body; presumably initTestSuite sets it - confirm against the test
% framework in use.
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
test_functions=localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
end
% the variable name 'test_functions' is kept as is, as initTestSuite
% presumably looks it up in this workspace (see the comment on 'try')
initTestSuite;
function test_dim_generalization_measure_basics
% Exercises cosmo_dim_generalization_measure in several stages:
%   1) illegal inputs that must raise an exception;
%   2) comparison with a hand-built expected result using delta_measure,
%      for radius 0 and 1;
%   3) behaviour with permuted samples;
%   4) spot checks using cosmo_correlation_measure and
%      cosmo_crossvalidation_measure (balanced and unbalanced targets).

% helper: assert that calling cosmo_dim_generalization_measure with the
% given arguments throws. The expected identifier is empty
% (presumably this matches any exception - confirm with the test framework)
aet=@(varargin)assertExceptionThrown(@()...
cosmo_dim_generalization_measure(varargin{:}),'');
% error on empty input
aet(struct());
% a struct with only a .samples field is rejected as well
ds=struct();
ds.samples=0;
aet(ds);
% build an MEEG dataset and append copies of its first two features
ds=cosmo_synthetic_dataset('type','meeg');
ds=cosmo_stack({ds,cosmo_slice(ds,1:2,2)},2);
% four time points, two channels
ds.fa.chan=[1 2 1 2 1 2 1 2];
ds.fa.time=[1 1 2 2 3 3 4 4];
ds.a.fdim.values{1}{end}='foochan';
ds.a.fdim.values{2}=[-1 0 1 2];
% the manually edited dataset must still be valid
cosmo_check_dataset(ds);
opt=struct();
opt.progress=false;
opt.measure=@delta_measure;
% measure without dimension, and dimension without measure, both throw
aet(ds,opt);
aet(ds,'dimension','time');
opt.dimension='time';
% 'time' is still a feature dimension at this point; this throws too
aet(ds,opt);
% move 'time' to the sample dimension
ds=cosmo_dim_transpose(ds,'time',1);
% measure must be a function handle
aet(ds,'dimension','time','measure','foo');
% chunks are required
chunks=ds.sa.chunks;
ds.sa=rmfield(ds.sa,'chunks');
aet(ds,opt);
ds.sa.chunks=chunks;
% chunks must be 1 and 2, not 1, 2 and 3
aet(ds,opt)
% swap the roles of targets and chunks
ds.sa.chunks=ds.sa.targets;
ds.sa.targets=chunks;
% partitions not allowed
aet(ds,opt,'partitions',cosmo_nfold_partitioner(ds));
% make the samples a deterministic function of channel, time, chunk and
% target, so that the expected output can be computed by hand below
ds.samples=bsxfun(@plus,(ds.fa.chan-1)*12,...
6*(ds.sa.time-1)+3*(ds.sa.chunks-1)+ds.sa.targets);
% append one more value to the first sample dimension, needed because the
% test set's time index is incremented by one below
ds.a.sdim.values{1}(end+1)=2;
% training set: chunk 1; test set: chunk 2 with every sample index
% repeated and the time index shifted by one
tr_ds=cosmo_slice(ds,ds.sa.chunks==1);
te_ds=cosmo_slice(ds,repmat(find(ds.sa.chunks==2),2,1));
te_ds.sa.time=te_ds.sa.time+1;
ds=cosmo_stack({tr_ds,te_ds});
% compare with a hand-built expected result for radius 0 and 1
for radius=0:1
unq_tr_time=unique(tr_ds.sa.time)';
unq_te_time=unique(te_ds.sa.time)';
% upper bound for the number of train/test time combinations
ntime=numel(unq_tr_time)*numel(unq_te_time);
expected_result_cell=cell(ntime,1);
% number of cells in expected_result_cell filled so far
pos=0;
% time points closer than 'radius' to either edge are skipped
for k=(1+radius):(numel(unq_tr_time)-radius)
tr_time=unq_tr_time(k);
% training samples within 'radius' of the current training time,
% with time moved back to the feature dimension
tr=cosmo_slice(tr_ds,abs(tr_ds.sa.time-tr_time)<=radius);
tr_tr=cosmo_dim_transpose(tr,'time',2);
for j=(1+radius):(numel(unq_te_time)-radius)
te_time=unq_te_time(j);
% same selection and transpose for the test samples
te=cosmo_slice(te_ds,abs(te_ds.sa.time-te_time)<=radius);
te_tr=cosmo_dim_transpose(te,'time',2);
both=cosmo_stack({tr_tr,te_tr},1,'drop_nonunique');
% keep only the first feature dimension
both.a.fdim.values=both.a.fdim.values(1);
both.a.fdim.labels=both.a.fdim.labels(1);
pos=pos+1;
% apply the measure directly and label the output with the
% indices (not the values) of the train and test time
res=delta_measure(both);
e=ones(size(res.samples));
res.sa.train_time=e*k;
res.sa.test_time=e*j;
expected_result_cell{pos}=res;
end
end
% only the first 'pos' cells were filled
expected_result=cosmo_stack(expected_result_cell(1:pos),1);
% set the sample dimension labels and values for train and test time
expected_result.a.sdim.labels=cell(1,2);
expected_result.a.sdim.labels{1}='train_time';
expected_result.a.sdim.labels{2}='test_time';
tr_dim=ds.a.sdim.values{1}(unq_tr_time);
te_dim=ds.a.sdim.values{1}(unq_te_time);
expected_result.a.sdim.values=cell(1,2);
expected_result.a.sdim.values{1}=tr_dim(:);
expected_result.a.sdim.values{2}=te_dim(:);
expected_result=cosmo_dim_prune(expected_result);
result=cosmo_dim_generalization_measure(ds,opt,'radius',radius);
assertEqual(result, expected_result);
end
% permute the samples
% NOTE(review): the original comment here stated that the result should be
% unaffected by permutation of the samples, but the active assertion
% below expects an exception; the equality check is commented out.
nsamples=size(ds.samples,1);
rp=randperm(nsamples);
ds_perm=cosmo_slice(ds,rp);
% fails if randperm happens to return the identity permutation
assertFalse(isequal(ds_perm,ds));
opt.radius=1;
assertExceptionThrown(@()cosmo_dim_generalization_measure(...
ds_perm,opt),'')
%result_perm=cosmo_dim_generalization_measure(ds_perm,opt);
%assertEqual(result_perm,result);
% try with correlation measure
% (features are duplicated and samples replaced by gaussian random data)
ds=cosmo_stack({ds,ds},2);
ds.samples=randn(size(ds.samples));
opt.radius=0;
opt.measure=@cosmo_correlation_measure;
opt.output='correlation';
% avoid Fisher transformation warning
% (the previous warning state is restored when 'cleaner' is destroyed)
warning_state=cosmo_warning();
cleaner=onCleanup(@()cosmo_warning(warning_state));
cosmo_warning('off');
result=cosmo_dim_generalization_measure(ds,opt);
% compute the measure directly for a single train/test time combination
ds1=cosmo_slice(ds,ds.sa.chunks==1 & ds.sa.time==1);
ds2=cosmo_slice(ds,ds.sa.chunks==2 & ds.sa.time==3);
c=opt.measure(cosmo_stack({ds1,ds2}),opt);
% select what should be the matching part of the output; train_time and
% test_time are presumably indices into the respective sets of unique
% time points (time==3 being the second test time point) - TODO confirm
result1=cosmo_slice(result,result.sa.train_time==1 & ...
result.sa.test_time==2);
assertElementsAlmostEqual(c.samples,result1.samples);
assertEqual(result1.sa.half1,c.sa.half1);
assertEqual(result1.sa.half2,c.sa.half2);
% try with crossvalidation measure
% swap chunks to get two samples in each class in the training set
ds.sa.chunks=3-ds.sa.chunks;
ds1=cosmo_slice(ds,ds.sa.chunks==2 & ds.sa.time==1);
ds2=cosmo_slice(ds,ds.sa.chunks==1 & ds.sa.time==3);
opt.measure=@cosmo_crossvalidation_measure;
opt.output='winner_predictions';
% no classifier has been set yet; the call is expected to fail with a
% platform-specific error identifier
if cosmo_wtf('is_matlab')
err_id='MATLAB:nonExistentField';
else
err_id='Octave:invalid-indexing';
end
assertExceptionThrown(@()...
cosmo_dim_generalization_measure(ds,opt),err_id);
opt.classifier=@cosmo_classify_lda;
result=cosmo_dim_generalization_measure(ds,opt);
% reference result: run the measure directly on the two selected subsets
ds_tiny=cosmo_stack({ds1,ds2});
opt.partitions=cosmo_nchoosek_partitioner(ds_tiny,1,'chunks',2);
r=opt.measure(ds_tiny,opt);
% give the reference result the same sample attributes as the
% generalization output
ones_=ones(size(r.samples,1),1);
r.sa.test_time=ones_*1;
r.sa.train_time=ones_*2;
r.sa=rmfield(r.sa,'time');
result1=cosmo_slice(result,result.sa.train_time==2 & ...
result.sa.test_time==1);
result1.sa=rmfield(result1.sa,'transpose_ids');
% tag NaN samples with unique attribute values before aligning
r=set_nan_samples_unique_sa(r);
result1=set_nan_samples_unique_sa(result1);
mp=cosmo_align(r.sa,result1.sa);
assertEqual(r.samples(mp),result1.samples);
% try with unbalanced partitions
% (a deterministic dummy classifier is used, and target 2 is relabelled
% as 3)
opt.classifier=@my_stupid_classifier;
ds.sa.orig_targets=ds.sa.targets;
ds.sa.targets(ds.sa.targets==2)=3;
ds1=cosmo_slice(ds,ds.sa.chunks==2 & ds.sa.time==1);
ds2=cosmo_slice(ds,ds.sa.chunks==1 & ds.sa.time==3);
ds_tiny=cosmo_stack({ds1,ds2});
opt.partitions=cosmo_nchoosek_partitioner(ds_tiny,1,'chunks',2);
opt.partitions=cosmo_balance_partitions(opt.partitions,ds_tiny);
r=opt.measure(ds_tiny,opt);
% 'ones_' is reused from the previous section, so 'r' must have the same
% number of samples as before
r.sa.test_time=ones_*1;
r.sa.train_time=ones_*2;
r.sa=rmfield(r.sa,'time');
% partitions are not allowed for cosmo_dim_generalization_measure
opt=rmfield(opt,'partitions');
result=cosmo_dim_generalization_measure(ds,opt);
result1=cosmo_slice(result,result.sa.train_time==2 & ...
result.sa.test_time==1);
result1.sa=rmfield(result1.sa,'transpose_ids');
% compare only the samples that are not NaN
r_msk=~isnan(r.samples);
result1_msk=~isnan(result1.samples);
r=cosmo_slice(r,r_msk);
result1=cosmo_slice(result1,result1_msk);
mp=cosmo_align(r.sa,result1.sa);
assertEqual(r.samples(mp),result1.samples);
function ds=set_nan_samples_unique_sa(ds)
% Tag each NaN element in ds.samples with a distinct value in ds.sa.attr
%
% Input:
%   ds          struct with a numeric field .samples
%
% Output:
%   ds          same struct with field .sa.attr set; it has the same size
%               as ds.samples, is NaN wherever ds.samples is not NaN, and
%               holds the distinct values numel(ds.samples)+(1:K) at the
%               K positions (in linear index order) where ds.samples is
%               NaN
%
% Note: the number of NaN elements is counted with nnz rather than sum,
% because sum on a matrix mask returns one count per column, which would
% give the wrong number of tag values.
nan_msk=isnan(ds.samples);
nsamples=numel(nan_msk);
n_nan=nnz(nan_msk);
ds.sa.attr=NaN(size(ds.samples));
ds.sa.attr(nan_msk)=nsamples+(1:n_nan);
function pred=my_stupid_classifier(x,y,z,unused)
% Deterministic dummy classifier
%
% Inputs:
%   x       training data; only the sort order of its elements is used
%   y       training labels; predictions are drawn from unique(y)
%   z       test data; only its number of rows is used
%   unused  ignored
%
% Output:
%   pred    one label for each row in z
n_test=size(z,1);
classes=unique(y);
n_classes=numel(classes);
% rank positions of all training values, in ascending order of value
[unused_sorted,sort_order]=sort(x(:));
% map the first n_test positions to class indices in 1..n_classes
class_idx=mod(sort_order(1:n_test),n_classes)+1;
pred=classes(class_idx);
function z=delta_func(x,y)
% All pairwise differences between column means of x and column means of y
%
% With P columns in x and Q columns in y, the output is a (P*Q)x1 vector
% in which element (j-1)*Q+i is mean(x(:,j))-mean(y(:,i)).
x_means_row=mean(x,1);
y_means_col=mean(y,1)';
pairwise_delta=bsxfun(@minus,x_means_row,y_means_col);
z=pairwise_delta(:);
function x=delta_measure(ds,unused)
% Simple deterministic measure used for testing
%
% Splits ds into samples with chunks==1 and all other samples, and stores
% the output of delta_func applied to the samples of these two parts.
%
% Inputs:
%   ds      dataset struct with fields .samples, .sa.chunks, .fa and
%           .a.fdim
%   unused  ignored
%
% Output:
%   x       struct based on the chunks==1 part of ds, with .samples set
%           to the delta_func output, .sa containing only the field .mu
%           (absolute value of .samples), and without .fa and .a.fdim
in_first_chunk=ds.sa.chunks==1;
first_part=cosmo_slice(ds,in_first_chunk);
other_part=cosmo_slice(ds,~in_first_chunk);
delta=delta_func(first_part.samples,other_part.samples);
x=first_part;
x.samples=delta;
% replace all sample attributes by a single one
x.sa=struct();
x.sa.mu=abs(delta);
% feature information does not apply to the output
x=rmfield(x,'fa');
x.a=rmfield(x.a,'fdim');