test flatten

function test_suite=test_flatten()
% tests for cosmo_flatten and cosmo_unflatten
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    % Collect handles to the local (sub)functions of this file so the
    % test framework can discover the test_* functions.
    % NOTE(review): the variable name 'test_functions' must be kept;
    % presumably the initTestSuite script reads it - confirm against
    % the test framework before renaming.
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % intentionally empty: on versions without localfunctions the
        % call fails and initTestSuite is relied upon by itself
    end
    % 'test_suite' is not assigned in this body; presumably the
    % initTestSuite script sets it - TODO confirm
    initTestSuite;

function test_flatten_and_unflatten()
% run the flatten/unflatten round-trip helper once for every combination
% of nsamples in 1:2, nfdim in 0:2 and dim in 1:2
    param_grid=cosmo_cartprod({1:2,0:2,1:2});
    ncombinations=size(param_grid,1);

    for row=1:ncombinations
        nsamples=param_grid(row,1);
        nfdim=param_grid(row,2);
        dim=param_grid(row,3);

        run_helper_test_flatten(nsamples,nfdim,dim);
    end

function run_helper_test_flatten(nsamples, nfdim, dim)
% round-trip test of cosmo_flatten and cosmo_unflatten for one setting
%
% run_helper_test_flatten(nsamples, nfdim, dim)
%
% Inputs:
%   nsamples      size of the dimension that is not flattened
%   nfdim         0: vector dimension values, in the orientation
%                    opposite to that used for nfdim=1
%                 1: vector dimension values
%                 2: matrix dimension values for labels 'i' and 'j'
%                    (requires opt.matrix_labels)
%   dim           1: flatten along samples ('sa', 'sdim')
%                 2: flatten along features ('fa', 'fdim')
%
% Valid calls are checked for their output. Invalid calls must raise an
% exception. Every invalid call is given the valid 'opt' so that the error
% is caused by the defect under test, not by missing matrix labels.
    aet_fl=@(varargin) assertExceptionThrown(@()...
                                cosmo_flatten(varargin{:}),'');
    aet_unfl=@(varargin) assertExceptionThrown(@()...
                                cosmo_unflatten(varargin{:}),'');

    ndata=nsamples*30; % 30 = 2*3*5 elements for dimensions i, j and k
    orig_labels={'i','j','k'};
    orig_values={[1:2;3 4],[1:3;3:-1:1],{'a','b','c','d','e'}};

    use_vector_values=nfdim<=1;
    transpose_vectors=nfdim==0;

    if use_vector_values
        % select first row only
        orig_values=cellfun(@(x)x(1,:),orig_values,...
                                        'UniformOutput',false);
    end

    % number of transposes applied to the input dimension values
    transpose_count=transpose_vectors+0;

    switch dim
        case 1
            data_shape=[2 3 5 nsamples];

            transpose_count=transpose_count+1;
            a_dim='sa';
            attr_dim='sdim';

        case 2
            data_shape=[nsamples 2 3 5];

            a_dim='fa';
            attr_dim='fdim';

        otherwise
            error('dim must be 1 or 2');
    end

    % 'opt' is valid for the values used here. Passing 'wrong_opt' must
    % make both cosmo_flatten and cosmo_unflatten raise an exception.
    opt=struct();
    if ~use_vector_values
        % matrix values: leaving out matrix_labels must fail
        opt.matrix_labels={'i','j'};
        wrong_opt=struct();
    elseif transpose_vectors
        % vector values: declaring them as matrix labels must fail
        wrong_opt=struct();
        wrong_opt.matrix_labels={'i','j'};
    else
        wrong_opt='this will raise an error because it is not a struct';
    end

    % sample attributes are expected as columns, feature attributes as rows
    if dim==1
        tr=@transpose;
    else
        tr=@(x)x;
    end

    if mod(transpose_count,2)==1
        tr_values=@transpose;
    else
        tr_values=@(x)x;
    end

    orig_values_tr=cellfun(tr_values,orig_values,'UniformOutput',false);

    data=reshape(1:ndata,data_shape);

    % flatten
    ds=cosmo_flatten(data,orig_labels,orig_values_tr,dim,opt);
    aet_fl(data,orig_labels,orig_values_tr,dim,wrong_opt);

    assertEqual(ds.samples(:),(1:ndata)');
    assertEqual(ds.(a_dim).i,tr(repmat([1 2],1,15)));
    assertEqual(ds.(a_dim).j,tr(repmat([1 1 2 2 3 3],1,5)));
    assertEqual(ds.(a_dim).k,tr(kron(1:5,ones(1,6))));

    % dimension values that do not match the data size must fail
    truncated_values=cellfun(@(x)x(1:2),orig_values_tr,...
                                'UniformOutput',false);
    aet_fl(data,orig_labels,truncated_values,dim,opt);

    % cellfun transposes each element, not the cell array itself
    expected_values=cellfun(tr,orig_values,'UniformOutput',false);
    expected_labels=cellfun(tr,orig_labels,'UniformOutput',false);

    assertEqual(ds.a.(attr_dim).values,expected_values);
    assertEqual(ds.a.(attr_dim).labels,expected_labels);

    % unflatten; for nfdim=0 the stored vector values are transposed
    % first, and the round trip must still give the expected output
    if transpose_vectors
        ds.a.(attr_dim).values=cellfun(@transpose,...
                                    ds.a.(attr_dim).values,...
                                    'UniformOutput',false);
    end

    [data2,labels,values]=cosmo_unflatten(ds,dim,opt);
    aet_unfl(ds,dim,wrong_opt);

    assertEqual(data,data2);
    assertEqual(labels,expected_labels);
    assertEqual(values,expected_values);

    % unflatten along the other dimension
    aet_unfl(ds,3-dim,opt);

    % dataset stacked with itself along the flattened dimension
    ds2=cosmo_stack({ds,ds},dim);
    aet_unfl(ds2,dim,opt);

    % fewer dimension values than dimension labels
    ds_bad=ds;
    ds_bad.a.(attr_dim).values=ds_bad.a.(attr_dim).values(1:(end-1));
    aet_unfl(ds_bad,dim,opt);

    % illegal dim argument
    aet_unfl(ds,3,opt);
    aet_unfl(ds,[1 1]*dim,opt);

function test_unflatten_exceptions()
% cosmo_unflatten accepts the synthetic dataset as is, but must throw
% once an extra top-level field 'foo' is added to it
    ds=cosmo_synthetic_dataset();

    % the unmodified dataset must not raise an error
    cosmo_unflatten(ds);

    % with the extra field, an exception (any identifier) is expected
    ds_extra_field=ds;
    ds_extra_field.foo=2;
    assertExceptionThrown(@()cosmo_unflatten(ds_extra_field),'');