test dim generalization measure

function test_suite=test_dim_generalization_measure()
% tests for cosmo_dim_generalization_measure
%
% Test-suite entry point: tries to collect handles to the local functions
% in this file and then invokes initTestSuite.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % NOTE(review): test_functions is not referenced explicitly below;
    % presumably initTestSuite picks it up from this workspace, so the
    % variable name should not be changed -- confirm against the test
    % framework in use
    initTestSuite;

function test_dim_generalization_measure_basics
% Exercises cosmo_dim_generalization_measure in four stages:
%   1) calls with invalid input, each of which must throw an exception
%   2) a custom measure (delta_measure) for radius 0 and 1, compared
%      against a result that is assembled by hand in this function
%   3) cosmo_correlation_measure, spot-checked for one train/test pair
%   4) cosmo_crossvalidation_measure, spot-checked for one train/test
%      pair, first with cosmo_classify_lda and then with a custom
%      classifier after relabelling the targets

    % helper: assert that cosmo_dim_generalization_measure throws when
    % called with the given arguments
    % (presumably the empty identifier string '' means that any exception
    %  is accepted -- confirm against the test framework)
    aet=@(varargin)assertExceptionThrown(@()...
                cosmo_dim_generalization_measure(varargin{:}),'');

    % error on empty input
    aet(struct());
    ds=struct();
    ds.samples=0;
    aet(ds);

    % start from a synthetic MEEG dataset and append a copy of its first
    % two features along the feature dimension
    ds=cosmo_synthetic_dataset('type','meeg');
    ds=cosmo_stack({ds,cosmo_slice(ds,1:2,2)},2);

     % four time points, two channels
    ds.fa.chan=[1 2 1 2 1 2 1 2];
    ds.fa.time=[1 1 2 2 3 3 4 4];
    ds.a.fdim.values{1}{end}='foochan';
    ds.a.fdim.values{2}=[-1 0 1 2];
    cosmo_check_dataset(ds);
    opt=struct();
    opt.progress=false;
    opt.measure=@delta_measure;

    % a measure without a dimension must fail ...
    aet(ds,opt);
    % ... as must a dimension without a measure ...
    aet(ds,'dimension','time');
    opt.dimension='time';
    % ... and so must both together at this point
    % (presumably because 'time' is still a feature dimension here; it is
    %  only transposed to the sample dimension below -- TODO confirm)
    aet(ds,opt);


    % move 'time' to dimension 1; from here on ds.sa.time is used
    ds=cosmo_dim_transpose(ds,'time',1);

    % measure must be a function handle
    aet(ds,'dimension','time','measure','foo');

    % chunks are required
    chunks=ds.sa.chunks;
    ds.sa=rmfield(ds.sa,'chunks');
    aet(ds,opt);
    ds.sa.chunks=chunks;

    % chunks must be 1 and 2, not 1, 2 and 3
    aet(ds,opt)

    % swap targets and chunks, so that the former targets serve as chunks
    ds.sa.chunks=ds.sa.targets;
    ds.sa.targets=chunks;

    % partitions not allowed
    aet(ds,opt,'partitions',cosmo_nfold_partitioner(ds));

    % replace the data by deterministic values computed from channel,
    % time, chunk and target of each sample / feature
    ds.samples=bsxfun(@plus,(ds.fa.chan-1)*12,...
                    6*(ds.sa.time-1)+3*(ds.sa.chunks-1)+ds.sa.targets);

    % add one more value to the sample dimension, so that the test set's
    % time indices can be shifted up by one below
    ds.a.sdim.values{1}(end+1)=2;
    % training set: all samples in chunk 1
    tr_ds=cosmo_slice(ds,ds.sa.chunks==1);
    % test set: every sample in chunk 2 is taken twice
    te_ds=cosmo_slice(ds,repmat(find(ds.sa.chunks==2),2,1));
    te_ds.sa.time=te_ds.sa.time+1;
    ds=cosmo_stack({tr_ds,te_ds});

    % for each radius, build the expected result by applying delta_measure
    % to every (train time, test time) pair and compare it with the output
    % of cosmo_dim_generalization_measure
    for radius=0:1
        unq_tr_time=unique(tr_ds.sa.time)';
        unq_te_time=unique(te_ds.sa.time)';

        % upper bound for the number of train/test pairs; for radius>0
        % fewer cells are filled, 'pos' counts those actually used
        ntime=numel(unq_tr_time)*numel(unq_te_time);
        expected_result_cell=cell(ntime,1);

        pos=0;
        % only time points that have 'radius' neighbours on either side
        % are visited
        for k=(1+radius):(numel(unq_tr_time)-radius)
            tr_time=unq_tr_time(k);
            % training samples within 'radius' of the current time point
            tr=cosmo_slice(tr_ds,abs(tr_ds.sa.time-tr_time)<=radius);
            tr_tr=cosmo_dim_transpose(tr,'time',2);
            for j=(1+radius):(numel(unq_te_time)-radius)
                te_time=unq_te_time(j);
                % test samples within 'radius' of the current time point
                te=cosmo_slice(te_ds,abs(te_ds.sa.time-te_time)<=radius);

                te_tr=cosmo_dim_transpose(te,'time',2);


                both=cosmo_stack({tr_tr,te_tr},1,'drop_nonunique');
                % keep only the first feature dimension
                both.a.fdim.values=both.a.fdim.values(1);
                both.a.fdim.labels=both.a.fdim.labels(1);
                pos=pos+1;

                res=delta_measure(both);
                % label the result with the train and test time indices
                e=ones(size(res.samples));
                res.sa.train_time=e*k;
                res.sa.test_time=e*j;
                expected_result_cell{pos}=res;
            end
        end

        % combine all pairs and set the two sample dimensions
        expected_result=cosmo_stack(expected_result_cell(1:pos),1);
        expected_result.a.sdim.labels=cell(1,2);
        expected_result.a.sdim.labels{1}='train_time';
        expected_result.a.sdim.labels{2}='test_time';

        tr_dim=ds.a.sdim.values{1}(unq_tr_time);
        te_dim=ds.a.sdim.values{1}(unq_te_time);

        expected_result.a.sdim.values=cell(1,2);
        expected_result.a.sdim.values{1}=tr_dim(:);
        expected_result.a.sdim.values{2}=te_dim(:);

        expected_result=cosmo_dim_prune(expected_result);

        result=cosmo_dim_generalization_measure(ds,opt,'radius',radius);
        assertEqual(result, expected_result);
    end

    % result should be unaffected by permutation of the samples
    % NOTE(review): the code below currently asserts the opposite of the
    % comment above, namely that an exception is thrown for the permuted
    % dataset; the equality check is commented out
    nsamples=size(ds.samples,1);
    rp=randperm(nsamples);
    ds_perm=cosmo_slice(ds,rp);
    assertFalse(isequal(ds_perm,ds));

    opt.radius=1;
    assertExceptionThrown(@()cosmo_dim_generalization_measure(...
                                                    ds_perm,opt),'')
    %result_perm=cosmo_dim_generalization_measure(ds_perm,opt);
    %assertEqual(result_perm,result);

    % try with correlation measure
    % (features are duplicated and the data is replaced by random values)
    ds=cosmo_stack({ds,ds},2);
    ds.samples=randn(size(ds.samples));

    opt.radius=0;
    opt.measure=@cosmo_correlation_measure;
    opt.output='correlation';

    % avoid Fisher transformation warning
    % (the previous warning state is restored when 'cleaner' is destroyed)
    warning_state=cosmo_warning();
    cleaner=onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');
    result=cosmo_dim_generalization_measure(ds,opt);


    % compute the measure directly for a single pair: time 1 in chunk 1
    % and time 3 in chunk 2
    ds1=cosmo_slice(ds,ds.sa.chunks==1 & ds.sa.time==1);
    ds2=cosmo_slice(ds,ds.sa.chunks==2 & ds.sa.time==3);
    c=opt.measure(cosmo_stack({ds1,ds2}),opt);

    % ... and compare with the matching entry in the result
    % (test times were shifted by one above, so presumably test index 2
    %  corresponds to time 3 -- TODO confirm)
    result1=cosmo_slice(result,result.sa.train_time==1 & ...
                                        result.sa.test_time==2);
    assertElementsAlmostEqual(c.samples,result1.samples);
    assertEqual(result1.sa.half1,c.sa.half1);
    assertEqual(result1.sa.half2,c.sa.half2);

    % try with crossvalidation measure
    % swap chunks to get two samples in each class in the training set
    ds.sa.chunks=3-ds.sa.chunks;
    ds1=cosmo_slice(ds,ds.sa.chunks==2 & ds.sa.time==1);
    ds2=cosmo_slice(ds,ds.sa.chunks==1 & ds.sa.time==3);
    opt.measure=@cosmo_crossvalidation_measure;
    opt.output='winner_predictions';

    % no classifier has been set yet, so the call must fail; the expected
    % error identifier differs between Matlab and Octave
    if cosmo_wtf('is_matlab')
        err_id='MATLAB:nonExistentField';
    else
        err_id='Octave:invalid-indexing';
    end
    assertExceptionThrown(@()...
            cosmo_dim_generalization_measure(ds,opt),err_id);

    opt.classifier=@cosmo_classify_lda;
    result=cosmo_dim_generalization_measure(ds,opt);

    % reference: run the measure directly on the single train/test pair
    ds_tiny=cosmo_stack({ds1,ds2});
    opt.partitions=cosmo_nchoosek_partitioner(ds_tiny,1,'chunks',2);
    r=opt.measure(ds_tiny,opt);
    ones_=ones(size(r.samples,1),1);
    r.sa.test_time=ones_*1;
    r.sa.train_time=ones_*2;
    r.sa=rmfield(r.sa,'time');

    result1=cosmo_slice(result,result.sa.train_time==2 & ...
                                        result.sa.test_time==1);
    result1.sa=rmfield(result1.sa,'transpose_ids');
    % add a sample attribute with a distinct value for each NaN sample
    % before the sample attributes are aligned
    r=set_nan_samples_unique_sa(r);
    result1=set_nan_samples_unique_sa(result1);

    mp=cosmo_align(r.sa,result1.sa);
    assertEqual(r.samples(mp),result1.samples);

    % try with unbalanced partitions
    % (a custom classifier is used and target 2 is relabelled as 3; the
    %  original targets are kept in .sa.orig_targets)
    opt.classifier=@my_stupid_classifier;
    ds.sa.orig_targets=ds.sa.targets;
    ds.sa.targets(ds.sa.targets==2)=3;

    ds1=cosmo_slice(ds,ds.sa.chunks==2 & ds.sa.time==1);
    ds2=cosmo_slice(ds,ds.sa.chunks==1 & ds.sa.time==3);
    ds_tiny=cosmo_stack({ds1,ds2});

    opt.partitions=cosmo_nchoosek_partitioner(ds_tiny,1,'chunks',2);
    opt.partitions=cosmo_balance_partitions(opt.partitions,ds_tiny);
    r=opt.measure(ds_tiny,opt);
    % NOTE(review): ones_ was sized from the previous 'r' above; this
    % assumes the new 'r' has the same number of samples -- confirm
    r.sa.test_time=ones_*1;
    r.sa.train_time=ones_*2;
    r.sa=rmfield(r.sa,'time');

    % partitions are not allowed as input (see above), so remove them
    opt=rmfield(opt,'partitions');
    result=cosmo_dim_generalization_measure(ds,opt);
    result1=cosmo_slice(result,result.sa.train_time==2 & ...
                                        result.sa.test_time==1);
    result1.sa=rmfield(result1.sa,'transpose_ids');

    % only samples that are not NaN are compared
    r_msk=~isnan(r.samples);
    result1_msk=~isnan(result1.samples);

    r=cosmo_slice(r,r_msk);
    result1=cosmo_slice(result1,result1_msk);

    mp=cosmo_align(r.sa,result1.sa);
    assertEqual(r.samples(mp),result1.samples);

function ds=set_nan_samples_unique_sa(ds)
% Tag every NaN entry in ds.samples with its own value in ds.sa.attr
%
% Input:
%   ds          struct with a numeric field .samples
%
% Output:
%   ds          the same struct with the field .sa.attr set; it has the
%               same size as .samples and contains
%               - NaN wherever .samples is not NaN
%               - the distinct values N+1, N+2, ... (with
%                 N=numel(.samples)) wherever .samples is NaN
    nan_msk=isnan(ds.samples);
    nsamples=numel(nan_msk);

    % count NaNs over all elements; sum(nan_msk) would give one count per
    % column if .samples is a matrix, so that not every NaN entry would
    % receive its own value
    nnan=nnz(nan_msk);

    ds.sa.attr=NaN(size(ds.samples));
    ds.sa.attr(nan_msk)=nsamples+(1:nnan);

function pred=my_stupid_classifier(x,y,z,unused)
% Dummy classifier: predictions are derived from the sort order of the
% training data and are always taken from the unique values in y
%
% Inputs:
%   x           training data; all its values are sorted as one vector
%   y           training labels
%   z           test data; only its number of rows is used
%   unused      ignored
%
% Output:
%   pred        one element of unique(y) for each row in z
    ntest=size(z,1);
    classes=unique(y);

    % positions of the elements of x(:) in ascending order
    [unused_sorted,order]=sort(x(:));

    % map the first ntest positions onto indices 1..numel(classes)
    class_idx=mod(order(1:ntest),numel(classes))+1;
    pred=classes(class_idx);

function z=delta_func(x,y)
% Differences between all column means of x and all column means of y
%
% Inputs:
%   x, y        numeric matrices
%
% Output:
%   z           column vector; for P columns in x and Q columns in y it
%               has P*Q elements, with the differences for the first
%               column of x (minus each column mean of y) coming first
    x_mean_row=mean(x,1);
    y_mean_col=mean(y,1)';

    % element (i,j) is mean of column j of x minus mean of column i of y
    pairwise_diff=bsxfun(@minus,x_mean_row,y_mean_col);
    z=pairwise_diff(:);

function x=delta_measure(ds,unused)
% Simple measure: differences between feature means of the samples in
% chunk 1 and feature means of all other samples
%
% Inputs:
%   ds          dataset struct with fields .samples, .sa.chunks, .fa and
%               .a.fdim
%   unused      ignored
%
% Output:
%   x           struct based on the chunk-1 part of ds, with
%               .samples    output of delta_func for the two parts
%               .sa.mu      absolute value of .samples
%               and with .fa and .a.fdim removed
    in_first_chunk=(ds.sa.chunks==1);
    first_part=cosmo_slice(ds,in_first_chunk);
    other_part=cosmo_slice(ds,~in_first_chunk);

    delta=delta_func(first_part.samples,other_part.samples);

    % the output has no feature attributes or feature dimensions
    x=rmfield(first_part,'fa');
    x.a=rmfield(x.a,'fdim');

    x.samples=delta;
    x.sa=struct();
    x.sa.mu=abs(delta);