test dim prune

function test_suite = test_dim_prune()
% tests for cosmo_dim_prune
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % handles to the local functions of this file; the variable name
        % 'test_functions' is presumably what initTestSuite reads -- do not
        % rename (NOTE(review): confirm against initTestSuite)
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % initTestSuite is expected to assign the output 'test_suite' from the
    % local test functions above -- NOTE(review): confirm
    initTestSuite;

function test_dim_prune_fmri()
% prune a synthetic dataset created with cosmo_synthetic_dataset defaults
    no_args={};
    helper_test_dim_prune(no_args{:});


function test_dim_prune_meeg_timelock()
% prune a synthetic timelocked MEEG dataset
    timelock_args={'type','timelock'};
    helper_test_dim_prune(timelock_args{:});

function test_dim_prune_meeg_source_mom()
% prune a synthetic MEEG source dataset whose data field is 'mom'
    mom_args={'type','source','data_field','mom'};
    helper_test_dim_prune(mom_args{:});

function test_dim_prune_meeg_source_pow()
% prune a synthetic MEEG source dataset whose data field is 'pow'
    pow_args={'type','source','data_field','pow'};
    helper_test_dim_prune(pow_args{:});

function test_dim_prune_surface()
% prune a synthetic surface dataset
    surface_args={'type','surface'};
    helper_test_dim_prune(surface_args{:});

function test_dim_prune_exceptions()
% cosmo_dim_prune must raise an error for every argument list below:
% a non-dataset input, a dangling option, and invalid values for the
% 'matrix_labels', 'dim' and 'labels' options
    ds=cosmo_synthetic_dataset();
    invalid_arg_lists={{struct},...
                       {ds,3},...
                       {ds,'matrix_labels','pos'},...
                       {ds,'dim',3},...
                       {ds,'labels',struct}};

    for i_case=1:numel(invalid_arg_lists)
        call_args=invalid_arg_lists{i_case};
        assertExceptionThrown(@()cosmo_dim_prune(call_args{:}),'');
    end

function test_dim_prune_default_dim()
% Pruning with the default 'dim' option must give the same result as
% pruning with 'dim' set explicitly to 1 (samples) or 2 (features).
% Presumably the default synthetic dataset has nothing to prune, so all
% results equal each other.
    ds=cosmo_synthetic_dataset();

    % the default-option result does not depend on the loop variable,
    % so compute it once
    ds_pruned=cosmo_dim_prune(ds);

    for dim=1:2
        ds_pruned_dim=cosmo_dim_prune(ds,'dim',dim);
        assertEqual(ds_pruned,ds_pruned_dim);
    end

function test_dim_prune_pos_sample_dim()
% Move the first feature dimension of a source dataset (assumed to be
% 'pos') to the sample dimension, with position value 1 absent.
% Pruning all labels, or 'pos' explicitly, must raise an error (no
% 'matrix_labels' given), while pruning only 'time' leaves the dataset
% unchanged.
    ds=cosmo_synthetic_dataset('type','source','data_field','mom');
    nsamples=size(ds.samples,1);
    pos=ds.fa.pos;

    % sample dimension built from the first feature dimension
    sdim=struct();
    sdim.labels=ds.a.fdim.labels(1);
    sdim.values=ds.a.fdim.values(1)';

    % one position per sample (column vector), with value 1 replaced by 2
    sample_pos=pos(1:nsamples)';
    sample_pos(sample_pos==1)=2;

    ds_tr=cosmo_dim_remove(ds,'pos');
    ds_tr.a.sdim=sdim;
    ds_tr.sa.pos=sample_pos;

    assertExceptionThrown(@()cosmo_dim_prune(ds_tr),'');
    assertExceptionThrown(@()cosmo_dim_prune(ds_tr,'labels',{'pos'}),'');

    ds_time_pruned=cosmo_dim_prune(ds_tr,'labels',{'time'});
    assertEqual(ds_time_pruned,ds_tr);


function test_dim_prune_label()
% After making value 2 of feature attribute 'i' absent, pruning with
% 'labels' set to {'i'} must change the dataset, whereas pruning {'j'}
% or {'k'} must leave it unchanged.
    ds=cosmo_synthetic_dataset();
    ds.fa.i(ds.fa.i==2)=1;

    % iterating over a cell row yields 1x1 cells, as expected by 'labels'
    for label={'i','j','k'}
        ds_pruned=cosmo_dim_prune(ds,'labels',label);
        if isequal(label{1},'i')
            assertFalse(isequal(ds_pruned,ds));
        else
            assertEqual(ds_pruned,ds);
        end
    end



function helper_test_dim_prune(varargin)
% Build a synthetic dataset (varargin is passed to cosmo_synthetic_dataset)
% and run helper_test_dim_prune_dim on it with each of the 'dim' values
% [], 1, 2 and [1 2].
% If the dataset has a 'pos' feature dimension and feature pruning is
% requested, pruning must fail unless 'pos' is given in 'matrix_labels'.
    ds=cosmo_synthetic_dataset(varargin{:});
    has_pos=cosmo_match({'pos'},ds.a.fdim.labels);

    prune_dims={[],1,2,[1 2]};
    for i_dim=1:numel(prune_dims)
        prune_dim=prune_dims{i_dim};

        % any() of an empty array is false, so [] never prunes features
        prunes_features=any(prune_dim==2);

        if ~(has_pos && prunes_features)
            helper_test_dim_prune_dim(ds,prune_dim);
            continue;
        end

        % 'pos' is a matrix label: it must be declared explicitly
        assertExceptionThrown(@()helper_test_dim_prune_dim(...
                                    ds,prune_dim),'');
        assertExceptionThrown(@()helper_test_dim_prune_dim(...
                                    ds,prune_dim,...
                                    'matrix_labels',{'foo'}),'');
        helper_test_dim_prune_dim(ds,prune_dim,...
                                    'matrix_labels',{'pos'});
    end



function ds_pruned=helper_test_dim_prune_dim(ds_orig,prune_dim,varargin)
% For each non-singleton feature dimension of ds_orig, make one randomly
% chosen value of that dimension absent, prune with cosmo_dim_prune, and
% verify the result.
%
% Inputs:
%   ds_orig     dataset with feature dimensions in .a.fdim and matching
%               feature attributes in .fa
%   prune_dim   value passed as the 'dim' option to cosmo_dim_prune
%   varargin    further options passed to cosmo_dim_prune
%
% Output:
%   ds_pruned   pruned dataset for the last dimension tested, or ds_orig
%               if no feature dimension had more than one value
%
% If prune_dim does not contain 2, the dataset must be unchanged.
% Otherwise, for the modified dimension the absent value must be removed
% from .a.fdim.values and its attribute re-indexed; all other dimensions
% and the samples must be unchanged.

    dim=2; % only feature dimensions are tested

    infixes='sf';
    infix=infixes(dim);
    dim_field=[infix 'dim'];
    attr_field=[infix 'a'];

    dim_labels=ds_orig.a.(dim_field).labels;
    n_dim=numel(dim_labels);

    % ensure the output is defined even if no dimension can be pruned
    ds_pruned=ds_orig;

    for dim_to_prune=1:n_dim
        ds=ds_orig;
        label=dim_labels{dim_to_prune};

        % values present in the dimension to prune
        attr=ds.(attr_field).(label);
        unq=unique(attr);
        n=numel(unq);
        if n<2
            % empty or singleton: no value can be made absent
            continue;
        end

        % pick a present value other than the smallest (unique() sorts),
        % and replace it by the smallest one so that it no longer occurs.
        % Choose among the values actually present rather than assuming
        % they are exactly 1..n.
        remove_value=unq(1+ceil(rand()*(n-1)));
        attr(attr==remove_value)=unq(1);
        ds.(attr_field).(label)=attr;

        ds_pruned=cosmo_dim_prune(ds,'dim',prune_dim,varargin{:});

        if all(prune_dim~=2)
            % feature dimension not pruned: nothing should have changed
            assertEqual(ds_pruned,ds);
            continue;
        end

        assertEqual(ds.samples,ds_pruned.samples);
        for k=1:n_dim
            dim_value=ds.a.(dim_field).values{k};

            if k==dim_to_prune
                [wanted_dim_pruned,wanted_attr_pruned]=...
                                helper_expected_pruned(dim_value,attr);
            else
                wanted_dim_pruned=dim_value;
                wanted_attr_pruned=ds.(attr_field).(dim_labels{k});
            end

            dim_pruned=ds_pruned.a.(dim_field).values{k};
            attr_pruned=ds_pruned.(attr_field).(dim_labels{k});

            assertEqual(dim_pruned,wanted_dim_pruned);
            assertEqual(attr_pruned,wanted_attr_pruned);
        end
    end


function [values_pruned,attr_pruned]=helper_expected_pruned(values,attr)
% Expected dimension values and attribute after pruning: keep only the
% columns of values whose index occurs in attr, and re-index attr so that
% it refers to the kept columns. attr keeps its original size and class.
    kept=unique(attr);
    values_pruned=values(:,kept);

    [unused,new_index]=ismember(attr,kept);
    attr_pruned=attr;
    attr_pruned(:)=new_index;