function test_suite = test_average_samples
% tests for cosmo_average_samples
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
%
% Entry point of this test file. TEST_SUITE is set by the initTestSuite
% script (not visible here).
% NOTE(review): presumably initTestSuite uses the 'test_functions'
% variable assigned below, when present, to find the local test
% functions of this file - confirm against the test framework in use.
try % assignment of 'localfunctions' is necessary in Matlab >= 2016
% handles to the local functions of this file
test_functions=localfunctions();
catch % no problem; early Matlab versions can use initTestSuite fine
% (presumably because localfunctions is unavailable there)
end
initTestSuite;
function test_average_samples_
% Basic averaging tests: with the default synthetic dataset, averaging
% (optionally with a ratio) only reorders samples; illegal options raise
% an error; and averages never combine samples of different targets.
ds=cosmo_synthetic_dataset();

% default options and ratio=.5 both leave the sample values unchanged,
% up to ordering
for opt_cell={{},{'ratio',.5}}
    opts=opt_cell{1};
    avg=cosmo_average_samples(ds,opts{:});
    assertElementsAlmostEqual(sort(avg.samples), sort(ds.samples));
    assertElementsAlmostEqual(sort(avg.samples(:,3)), ...
                                    sort(ds.samples(:,3)));
end

% check wrong inputs
assert_fails=@(varargin)assertExceptionThrown(@()...
                    cosmo_average_samples(varargin{:}),'');
assert_fails(ds,'ratio',.1);
assert_fails(ds,'ratio',3);
assert_fails(ds,'ratio',.5,'count',2);

% all samples in a single chunk still give a valid dataset
ds.sa.chunks(:)=1;
cosmo_check_dataset(cosmo_average_samples(ds,'ratio',.5));

% keep only the third feature and store target*1000 + sample index
ds=cosmo_slice(ds,3,2);
nsamples=size(ds.samples,1);
ds.samples=ds.sa.targets*1000+(1:nsamples)';

for opt_cell={{'ratio',.5,'nrep',10},{'count',3,'nrep',10}}
    opts=opt_cell{1};
    avg=cosmo_average_samples(ds,opts{:});
    % no mixing of different targets
    frac=avg.samples/1000-avg.sa.targets;
    assertTrue(all(.00099<=frac & frac<.05));
    assertElementsAlmostEqual(frac*3000,round(frac*3000));
end
function test_average_samples_split_by
% For every subset of {targets, chunks, subject, modality}, averaging
% with 'split_by' must give (in some order) the mean of each unique
% combination of the selected sample attributes, with matching sample
% attributes in the output. Splitting by targets and chunks must equal
% the default behaviour.
plural_singular={'targets','targets';...
                'chunks','chunks';...
                'subjects','subject';...
                'modalities','modality';...
                };
n_dim=size(plural_singular,1);
combis=cosmo_cartprod(repmat({{true,false}},n_dim,1)');
for k=1:size(combis,1)
    combi=cell2mat(combis(k,:));
    % sample attribute (singular) names to split by, as a column cell.
    % Select column 2 explicitly; linear indexing of the n_dim x 2 cell
    % with combi would depend on the column layout and on orientation.
    split_by=plural_singular(combi,2);

    opt=struct();
    opt.seed=0; % truly random data
    for j=1:n_dim
        count=ceil(rand()*2+1);
        opt.(['n' plural_singular{j,1}])=count;
    end
    ds=cosmo_synthetic_dataset(opt);

    if isempty(split_by)
        % no attributes to split by: all samples are averaged together
        idx={1:(size(ds.samples,1))};
        unq_cell={};
    else
        values=cellfun(@(fn)ds.sa.(fn),split_by,'UniformOutput',false);
        [idx,unq_cell]=cosmo_index_unique(values);
    end

    n_avg=numel(idx);
    n_features=size(ds.samples,2);
    expected_samples=zeros(n_avg,n_features);
    for m=1:n_avg
        expected_samples(m,:)=mean(ds.samples(idx{m},:),1);
    end

    result=cosmo_average_samples(ds,'split_by',split_by);
    assertEqual(size(result.samples),size(expected_samples));

    % match each result row to an expected row using the first feature;
    % mapping(i) is the result row corresponding to expected row i
    delta=bsxfun(@minus,result.samples(:,1),expected_samples(:,1)');
    mapping=zeros(1,n_avg);
    for m=1:n_avg
        [mn,mn_idx]=min(abs(delta(m,:)));
        assert(mn<1e-5); % deal with rounding
        mapping(mn_idx)=m;
    end
    assertEqual(sort(mapping),1:n_avg);

    result_perm=cosmo_slice(result,mapping);
    assertElementsAlmostEqual(result_perm.samples,expected_samples);
    for pos=1:numel(split_by)
        assertEqual(unq_cell{pos},result_perm.sa.(split_by{pos}));
    end

    % check default result
    if isequal(split_by',{'targets','chunks'})
        default_result=cosmo_average_samples(ds);
        assertEqual(result,default_result);
    end
end
function test_average_samples_split_by_empty()
% With an empty 'split_by', all samples are averaged into a single one.
ntargets=ceil(rand()*5+2);
nchunks=ceil(rand()*5+2);
ds=cosmo_synthetic_dataset('ntargets',ntargets,'nchunks',nchunks);

averaged=cosmo_average_samples(ds,'split_by',{});
expected=mean(ds.samples,1);
assertElementsAlmostEqual(averaged.samples,expected);
function test_average_samples_exceptions
% Each argument list in BAD_ARGS must make cosmo_average_samples throw.
ds=cosmo_synthetic_dataset('nreps',5);

% struct with samples only (no sample attributes)
x=struct();
x.samples=randn(4);

% dataset without the 'targets' sample attribute
ds_bad=ds;
ds_bad.sa=rmfield(ds_bad.sa,'targets');

bad_args={{[]};...                              % not a dataset
          {x};...                               % samples only
          {ds,'count',6};...                    % illegal count
          {ds,'count',[2 2]};...
          {ds,'count',3.5};...
          {ds,'count',0};...
          {ds,'ratio',1.2};...                  % illegal ratio
          {ds,'ratio',-0.2};...
          {ds,'ratio',[.5 .5]};...
          {ds,'ratio',.5,'count',2};...         % mutually exclusive
          {ds,'repeats',[2 2]};...
          {ds,'repeats',-1};...
          {ds,'resamplings',[2 2]};...
          {ds,'resamplings',-1};...
          {ds,'resamplings',1,'repeats',1};...
          {ds_bad};...                          % not existing field
          {ds,'split_by',[]};...                % illegal split-by
          {ds,'split_by',struct()};...
          {ds,'split_by','foo'};...
          {ds,'split_by',{1,2}}};

for i_case=1:numel(bad_args)
    args=bad_args{i_case};
    assertExceptionThrown(@()cosmo_average_samples(args{:}),'');
end
function test_average_samples_with_repeats
% Averaging with 'count', 'ratio', 'repeats' and 'resamplings' must select
% samples in a balanced way. To verify this, the feature, chunk, target
% and repeat of every sample are encoded in its value (binarize_ds), so
% that check_with_helper can recover which samples formed each average.
nchunks=ceil(rand()*4+3);
ntargets=ceil(rand()*4+3);
ncombi_max=ceil(rand()*3+4);
max_cyc=5;
ncombi_min=ceil(ncombi_max/2);

ds=cosmo_synthetic_dataset('nchunks',nchunks,...
                            'ntargets',ntargets,...
                            'nreps',ncombi_max);
ds.sa=rmfield(ds.sa,'rep');
parts=cosmo_split(ds,{'targets','chunks'});

% keep a subset of each target/chunk combination with between
% ncombi_min and ncombi_max-1 repeats; the first one gets exactly
% ncombi_min so that the minimum is always present
combi_count=zeros(nchunks,ntargets);
for i_part=1:numel(parts)
    if i_part>1
        nkeep=ncombi_min+floor(rand()*(ncombi_max-ncombi_min));
    else
        nkeep=ncombi_min;
    end
    part=cosmo_slice(parts{i_part},1:nkeep);
    part.sa.repeats=(1:nkeep)';
    combi_count(part.sa.chunks(1),part.sa.targets(1))=nkeep;
    parts{i_part}=part;
end
% every part must have kept at least one sample
assert(all(cellfun(@(x)size(x.samples,1),parts)));
ds=cosmo_stack(parts);
nfeatures=size(ds.samples,2);

% bit widths for features, chunks, targets, and repeats; the repeats
% field has ceil(log2(max_cyc+1)) more bits than there are repeats
bws=[nfeatures,nchunks,ntargets,ceil(log2(max_cyc+1))+ncombi_max];
% encode features, chunks, targets and repeats into single number
dsb=binarize_ds(ds,bws);

check_with=@(args,count,repeats) check_with_helper(dsb,args,count,...
                                            repeats,nchunks,ntargets,...
                                            ncombi_max,combi_count,bws);

for repeats=[1,ceil(rand()*ncombi_max)]
    for count=[1,ceil(rand()*ncombi_min)]
        check_with({'count',count,'repeats',repeats},count,repeats);
    end
    for ratio=[.5,.3+rand()*.7]
        count=round(ratio*min(combi_count(:)));
        check_with({'ratio',ratio,'repeats',repeats},count,repeats);
    end
end

for resamplings=[0,1,2+round(rand()*4)]
    count=ceil(rand()*ncombi_min);
    if resamplings>0
        repeats=floor(resamplings*ncombi_min/count);
        args={'count',count,'resamplings',resamplings};
    else
        repeats=floor(ncombi_min/count);
        args={'count',count};
    end
    check_with(args,count,repeats);
end
function check_with_helper(dsb, args, count, repeats,...
                            nchunks, ntargets, ncombi_max, combi_count, bws)
% Average the binarized dataset DSB with options ARGS and verify that:
%  - all features of an averaged sample were built from the same repeats;
%  - no repeat is used twice within one average;
%  - each average is based on exactly COUNT samples;
%  - within each chunk/target combination, the available repeats
%    (1..combi_count) are used equally often (up to a difference of one)
%    and the other repeats are never used;
%  - each chunk/target combination contributes COUNT*REPEATS samples.
%
% COMBI_COUNT must be an NCHUNKS x NTARGETS matrix with the number of
% available repeats per chunk/target combination; BWS are the bit widths
% used to binarize DSB.
mu=cosmo_average_samples(dsb,args{:});
[chunks,targets,ids]=unbinarize_ds(mu, bws, count);
nsamples=size(ids,1);
nfeatures=size(dsb.samples,2);

% the counts matrix must describe the same chunks and targets
assertEqual(size(combi_count),[nchunks,ntargets]);

% chunk, target, repeat count
ctr_count=zeros(nchunks,ntargets,ncombi_max);
for j=1:nsamples
    % select same samples for all features
    id=ids{j,1};
    for k=2:nfeatures
        assertEqual(id,ids{j,k});
    end
    % no repeats
    assert(all(diff(sort(id(:)))>0));
    % count should match
    assertEqual(numel(id),count);
    ctr_count(chunks(j),targets(j),id)=...
                    ctr_count(chunks(j),targets(j),id)+1;
end

% ensure each sample selected about equally often
for k=1:nchunks
    for j=1:ntargets
        c=squeeze(ctr_count(k,j,:));
        n_available=combi_count(k,j);
        pre=c(1:n_available);
        assert(max(pre)-min(pre)<=1);
        post=c((n_available+1):end);
        assert(all(post==0));
    end
end

% check each target and chunk combination was used the correct number
% of times to form the average
ct_count=sum(ctr_count,3);
expected_ct_count=count*repeats*ones(nchunks,ntargets);
assert(isequal(ct_count, expected_ct_count));
function [chunks,targets,ids]=unbinarize_ds(ds, bws, counts)
% Decode averaged samples of a dataset that was encoded by binarize_ds.
%
% Inputs:
%   ds       dataset whose samples are averages of binarized samples
%   bws      bit widths (features, chunks, targets, repeats) used to
%            binarize the data
%   counts   number of samples that were averaged into each sample
% Outputs:
%   chunks   nsamples x 1 chunk of each averaged sample
%   targets  nsamples x 1 target of each averaged sample
%   ids      nsamples x nfeatures cell; each element holds the repeat
%            indices of the samples that formed that averaged value
%
% Asserts that the encoded feature index matches the column, and that
% chunks and targets agree across features of the same sample.
[nsamples, nfeatures]=size(ds.samples);
ids=cell(nsamples,nfeatures);
chunks=zeros(nsamples,1);
targets=zeros(nsamples,1);
repeat_modulus=2^bws(end);
for k=1:nsamples
    for j=1:nfeatures
        value=ds.samples(k,j);

        % Undo the averaging. The sum of the averaged samples is an
        % integer, but value*counts can be off by a floating-point
        % rounding error, so round it after checking it is nearly
        % integral. quick_dec2bin requires an exact integer.
        summed=value*counts;
        summed_int=round(summed);
        assert(abs(summed-summed_int)<1e-3);

        % Decode repeats: the lowest bws(end) bits hold one bit for each
        % repeat that contributed (multiple repeats can be present)
        v_id=quick_dec2bin(mod(summed_int,repeat_modulus),bws(end));
        ids{k,j}=bws(end)-find(v_id)+1;

        % decode feature, chunk and target from the higher bits, which
        % are the same for all averaged samples
        v=decode(floor(value/repeat_modulus),bws(1:(end-1)));
        assertEqual(log2(v(1))+1,j);
        c=log2(v(2))+1;
        t=log2(v(3))+1;
        if j==1
            chunks(k)=c;
            targets(k)=t;
        else
            assertEqual(c,chunks(k));
            assertEqual(t,targets(k));
        end
    end
end
function bds=binarize_ds(ds, bws)
% Return a copy of DS where every sample value is replaced by encode() of
% [feature index, chunk, target, repeat] of that sample, using bit
% widths BWS. Requires sample attributes chunks, targets and repeats.
bds=ds;
nfeatures=size(ds.samples,2);
for row=1:size(ds.samples,1)
    row_sa=cosmo_slice(ds.sa,row,1,'struct');
    row_attrs=[row_sa.chunks row_sa.targets row_sa.repeats];
    for col=1:nfeatures
        bds.samples(row,col)=encode([col row_attrs],bws);
    end
end
function p=encode(vs, bws)
% Encode several positive integers in a single number using one-hot bit
% fields. Value vs(k) is stored in a field of bws(k) bits in which only
% bit vs(k) (counting from the least significant bit, starting at 1) is
% set, so the field's value is 2^(vs(k)-1). Fields are concatenated from
% most significant (vs(1)) to least significant (vs(end)):
%   encode([X1 ... Xn]) = bin2dec([F1 ... Fn])
% with Fk the bws(k)-bit binary representation of 2^(Xk-1).
%
% Each vs(k) must be an integer in the range 1..bws(k); otherwise the
% field would be resized (vs(k)==0) or invalid, corrupting the encoding.
n=numel(bws);
assert(numel(vs)==n);
bs=cell(1,n);
for k=1:n
    bw=bws(k);
    v=vs(k);
    assert(round(v)==v && v>=1 && v<=bw, ...
            'value %d does not fit in a one-hot field of width %d', v, bw);
    bs{k}=zeros(1,bw);
    bs{k}(bw-v+1)=1;
end
p=quick_bin2dec(cat(2,bs{:}));
function vs=decode(p, bws)
% Split a single non-negative integer P into consecutive bit fields and
% return the value of each field (the inverse of the field concatenation
% done by encode):
%   decode(P) = [bin2dec(PB1) ... bin2dec(PBn)]
% where PB1 ... PBn are consecutive parts of the sum(bws)-bit binary
% representation of P, with PBk having bws(k) bits (PB1 most significant).
% For one-hot fields produced by encode, field k holds 2^(Xk-1) for the
% encoded value Xk, so Xk is recovered as log2(vs(k))+1.
arr=quick_dec2bin(p,sum(bws));
n=numel(bws);
vs=zeros(1,n);
field_ends=cumsum(bws);
field_starts=field_ends-bws+1;
for k=1:n
    vs(k)=quick_bin2dec(arr(field_starts(k):field_ends(k)));
end
function arr=quick_dec2bin(x,bw)
% Convert non-negative integer X to a row vector of length BW holding its
% binary digits (most significant first, values 0 and 1), left-padded
% with zeros. Raises an error if X is not a non-negative integer or does
% not fit in BW bits. Negative values are rejected explicitly because
% newer dec2bin versions return a two's-complement string for them.
assert(round(x)==x && x>=0, 'x must be a non-negative integer');
xbs=dec2bin(x);
nbits=numel(xbs);
assert(nbits<=bw, 'value %d needs %d bits, but only %d are available', ...
            x, nbits, bw);
arr=zeros(1,bw);
arr(bw-nbits+1:end)=(xbs=='1');
function x=quick_bin2dec(arr)
% Convert a vector of binary digits (most significant first) to the
% decimal number it represents. ARR may be a row or column vector; it is
% flattened to a row first so that a column does not broadcast against
% the row of powers of two. An empty ARR gives 0.
digits=arr(:)';
x=sum(2.^((numel(digits)-1):-1:0).*digits);