test dim match

function test_suite=test_dim_match
% tests for cosmo_dim_match
%
% Output:
%   test_suite     not assigned explicitly in this function; presumably
%                  set by the initTestSuite script invoked at the end
%                  (TODO confirm against the test framework in use)
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % test_functions is not read anywhere else in this function;
        % presumably initTestSuite picks it up from this workspace
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % NOTE(review): initTestSuite is defined outside this file; it is
    % invoked without arguments, after the try/catch block above
    initTestSuite;

function test_dim_match_basics()
% Run the cosmo_dim_match checks on a dataset with shuffled samples and
% features, for both dimensions and for selections of the first, the
% second, and both of the first two dimension labels
    ds=get_dim_dataset();

    % shuffle the samples first, then the features
    n_rows=size(ds.samples,1);
    n_cols=size(ds.samples,2);
    ds=cosmo_slice(ds,randperm(n_rows));
    ds=cosmo_slice(ds,randperm(n_cols),2);

    % each entry holds the indices of the dimension labels to select
    label_selections={1,2,[1 2]};

    for dim=1:2
        for j=1:numel(label_selections)
            verify_with(ds,dim,label_selections{j});
        end
    end

function verify_with(ds, dim, sel)
% Run the checks twice: on the dataset itself, and on a neighborhood-like
% struct derived from it
%
% The derived struct has no .samples field and no attributes for the
% dimension other than dim; it has a .neighbors cell with one (empty)
% element for each element along dim, and a copy of .a in .origin.a
    verify_with_dataset_or_neighborhood(ds,dim,sel);

    % prefix of the dimension that is *not* under test
    other_prefixes='fs';
    other=other_prefixes(dim);

    n_along_dim=size(ds.samples,dim);

    % drop the samples and everything describing the other dimension
    nh=rmfield(ds,{[other 'a'],'samples'});
    nh.a=rmfield(nh.a,[other 'dim']);

    nh.neighbors=cell(n_along_dim,1);
    nh.origin.a=nh.a;

    verify_with_dataset_or_neighborhood(nh,dim,sel);



function verify_with_dataset_or_neighborhood(ds, dim, sel)
% Compare the output of cosmo_dim_match with a mask computed directly
% through cosmo_match
%
% Inputs:
%   ds      struct with fields .a.sdim and .sa (when dim==1) or .a.fdim
%           and .fa (when dim==2), and with either a .samples or a
%           .neighbors field
%   dim     dimension to test: 1 uses the 's' fields, 2 the 'f' fields
%   sel     vector with indices of the dimension labels to select
%
% For each selected label a random non-empty subset of its values is
% chosen. cosmo_dim_match is called with these values directly and with
% function handles matching them, both with and without dim as trailing
% argument; all four results must be equal to the expected mask.
    infixes='sf';
    infix=infixes(dim);

    % get sdim or fdim
    xdim=ds.a.([infix 'dim']);
    labels=xdim.labels;
    values=xdim.values;

    % get fa or sa
    xa=ds.([infix 'a']);

    % number of elements along dim; a struct without .samples is treated
    % as a neighborhood
    if isfield(ds,'samples')
        n_elements=size(ds.samples,dim);
    else
        n_elements=numel(ds.neighbors);
    end

    % expected mask; narrowed down for each selected label
    msk=true(n_elements,1);

    n_sel=numel(sel);
    xdim_labels=cell(n_sel,1);
    xa_values=cell(n_sel,1);
    sel_fhandles=cell(n_sel,1);

    for k=1:n_sel
        s=sel(k);

        value=values{s};
        label=labels{s};

        xdim_labels{k}=label;

        % select a random, non-empty subset of the values
        n_values=numel(value);
        rp=randperm(n_values);
        xa_idx=rp(1:ceil(rand()*n_values));
        chosen_values=value(xa_idx);
        xa_values{k}=chosen_values;

        % the attribute is matched against the selected indices
        m=cosmo_match(xa.(label),xa_idx);

        sel_fhandles{k}=@(x)cosmo_match(x,chosen_values);

        msk=msk & m(:);
    end

    % the expected mask is a row vector for the second dimension
    if dim==2
        msk=reshape(msk,1,[]);
    end

    % labels and values
    args=[xdim_labels(:)';xa_values(:)'];
    dm=cosmo_dim_match(ds,args{:});
    assertEqual(dm,msk);

    dm2=cosmo_dim_match(ds,args{:},dim);
    assertEqual(dm2,msk);

    % labels and function handles
    args=[xdim_labels(:)';sel_fhandles(:)'];
    dm=cosmo_dim_match(ds,args{:});
    assertEqual(dm,msk);

    dm2=cosmo_dim_match(ds,args{:},dim);
    assertEqual(dm2,msk);




function test_exceptions()
% Illegal inputs to cosmo_dim_match must raise an exception
    must_throw=@(varargin)assertExceptionThrown(@()...
                    cosmo_dim_match(varargin{:}),'');

    % first argument is an empty cell or a struct without any fields
    must_throw({},'i',1);
    must_throw(struct(),'i',1);

    % argument lists that must be rejected for the dimension dataset
    ds=get_dim_dataset();
    bad_args={...
        {'i','a'},...               % input mismatch
        {'targets',2},...
        {'chunks','2'},...
        {'chunks',1,'i',4},...      % cannot mask in both dimensions
        {[1 2],1},...               % size mismatch
        {'chunks',1,1:10,2},...
        {'i',1,1},...               % wrong dimension
        {'foo',1},...
        {'foo',1,2},...
        {'i',zeros(2)},...          % no vector input
        {'i',{1}}...                % illegal cell input
        };

    for k=1:numel(bad_args)
        must_throw(ds,bad_args{k}{:});
    end

    % no mixing of dimensions, also when the dataset is square
    ds=cosmo_synthetic_dataset();
    ds.a.sdim.labels={'targets','chunks'};
    ds.a.sdim.values={1:3,4:6};
    assertEqual(size(ds.samples),[6 6]);
    must_throw(ds,'i',1,'targets',1);



function ds=get_dim_dataset()
% Build a synthetic dataset (requested with 3 chunks and 4 targets) and
% attach labels and values for the sample dimension in .a.sdim
    target_values={'foo','bar','baz','bazz'};
    chunk_values=4:6;

    ds=cosmo_synthetic_dataset('size','normal','nchunks',3,'ntargets',4);

    ds.a.sdim.labels={'targets','chunks'};
    ds.a.sdim.values={target_values,chunk_values};