test fx

function test_suite = test_fx
% tests for cosmo_fx
%
% Entry point of this test file: it gathers handles to the local test
% functions (where the running Matlab supports that) and then runs the
% initTestSuite script of the unit test framework.
% NOTE(review): initTestSuite is presumably what assigns the 'test_suite'
% output, as nothing in this function sets it explicitly - confirm
% against the test framework in use.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % the variable name 'test_functions' is kept as-is; presumably
        % initTestSuite looks it up by this name
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % deliberately empty: failure of localfunctions() is ignored
    end
    initTestSuite;

function test_fx_basics()
    % Verify cosmo_fx for every combination of a function handle and a
    % split-by specification, each time on a dataset whose rows have been
    % put in a fresh random order
    dataset=cosmo_synthetic_dataset('nchunks',3,'ntargets',5);
    row_count=size(dataset.samples,1);

    % sample attribute names to split by; the empty cell means no split
    split_specs={{},{'chunks'},{'targets'},{'chunks','targets'}};

    % functions to apply: identity, absolute value, element-wise minimum
    % with 1, and selection of at most the first two rows
    funcs={@(x)x, @abs, @(x)min(x,1), @(x)x(1:min(size(x,1),2),:)};

    for func_idx=1:numel(funcs)
        func=funcs{func_idx};

        for spec_idx=1:numel(split_specs)
            % new random row order for each combination
            shuffled=cosmo_slice(dataset,randperm(row_count));

            helper_assert_fx_matches(shuffled,func,split_specs{spec_idx});
        end
    end

function helper_assert_fx_matches(ds,f,split_by)
    % Assert that cosmo_fx(ds,f,split_by) matches a reference computation
    %
    % Inputs:
    %   ds         dataset struct with fields .samples, .fa and .a; it
    %              must also have .sa.(key) for every key in split_by
    %   f          function handle applied to the .samples of each group
    %   split_by   cell with sample attribute names used for grouping;
    %              an empty cell means that all rows form a single group
    %
    % The reference is built by grouping the rows of ds by the unique
    % combinations of the split_by attributes, applying f to each group,
    % and comparing against consecutive row blocks of the cosmo_fx output.
    % This relies on cosmo_fx emitting its groups in the same order as
    % cosmo_index_unique returns them.
    res=cosmo_fx(ds,f,split_by);

    n_split_by=numel(split_by);
    if n_split_by==0
        % no splitting: a single group with all rows
        idxs={1:size(ds.samples,1)};
    else
        % collect the sample attribute values for each split key
        unq_vals=cell(n_split_by,1);
        for k=1:n_split_by
            key=split_by{k};
            unq_vals{k}=ds.sa.(key);
        end

        % row indices for each unique combination of attribute values
        idxs=cosmo_index_unique(unq_vals);
    end

    n_unq=numel(idxs);

    % number of rows of the output that have been verified so far
    pos=0;

    for k=1:n_unq
        d=cosmo_slice(ds,idxs{k});
        s=f(d.samples);
        n_s=size(s,1);

        % template with as many rows as f produced for this group (the
        % first row of the group, repeated); its samples are replaced by
        % the output of f
        expected_res_part=cosmo_slice(d,ones(n_s,1));
        expected_res_part.samples=s;

        % the block of output rows that corresponds to this group
        res_part=cosmo_slice(res,pos+(1:n_s));

        % do not care about .sa; but rest should match
        assertEqual(expected_res_part.samples,res_part.samples);
        assertEqual(expected_res_part.fa,res_part.fa);
        assertEqual(expected_res_part.a,res_part.a);

        pos=pos+n_s;
    end

    % the output must not contain rows beyond those verified above;
    % without this check surplus rows would go unnoticed
    assertEqual(size(res.samples,1),pos);

function test_fx_unequal_output_size_other_dim
    % cosmo_fx is expected to throw when the function returns, for each
    % group, a row vector with as many columns as the group has rows
    dataset=cosmo_synthetic_dataset('nchunks',3,'ntargets',5);

    % relabel the first sample; presumably this makes the number of
    % samples differ across targets, and with it the output widths
    dataset.sa.targets(1)=2;

    row_of_ones=@(x)ones(1,size(x,1));
    call_fx=@()cosmo_fx(dataset,row_of_ones,{'targets'});
    assertExceptionThrown(call_fx,'');

function test_fx_feature_dim
    % Verify cosmo_fx when splitting along the second (feature) dimension
    % by each of the feature attributes 'i', 'j' and 'k'
    dataset=cosmo_synthetic_dataset();
    dim_labels={'i','j','k'};

    % both functions reduce the columns of their input to a single column
    reducers={@(x)max(x,[],2),@(x)sum(x,2)};

    for reducer_idx=1:numel(reducers)
        reducer=reducers{reducer_idx};

        for label_idx=1:numel(dim_labels)
            dim_label=dim_labels{label_idx};
            result=cosmo_fx(dataset,reducer,{dim_label},2);

            % sample attributes and dataset attributes must be unchanged
            assertEqual(result.sa,dataset.sa);
            assertEqual(result.a,dataset.a);

            % each output column must equal the function applied to the
            % input columns sharing one unique attribute value, with
            % columns ordered as the output of unique()
            label_values=dataset.fa.(dim_label);
            unique_values=unique(label_values);

            for col=1:numel(unique_values)
                mask=label_values==unique_values(col);
                expected_col=reducer(dataset.samples(:,mask));
                assertEqual(expected_col,result.samples(:,col));
            end
        end
    end


function test_fx_illegal_input_arguments
    % Every call to cosmo_fx below is expected to throw an exception
    dataset=cosmo_synthetic_dataset();

    % forwards all arguments to cosmo_fx and asserts that it throws
    assert_fx_throws=@(varargin)assertExceptionThrown(...
                                    @()cosmo_fx(varargin{:}),'');

    % first argument is an empty struct instead of the synthetic dataset
    assert_fx_throws(struct,@abs,{});

    % second argument is [] instead of a function handle
    assert_fx_throws(dataset,[],{});

    % split by 'i' without a dimension argument
    assert_fx_throws(dataset,@abs,{'i'});

    % split by 'targets' with dimension argument 2
    assert_fx_throws(dataset,@abs,{'targets'},2);