test meeg io

function test_suite=test_meeg_io()
% tests for MEEG input/output
%
% Entry point of this test file: it tries to collect handles to the local
% functions in 'test_functions' and then runs the initTestSuite script.
% NOTE(review): initTestSuite is not defined in this file; presumably it
% is the unit test framework script that reads 'test_functions' and
% assigns the 'test_suite' output - confirm against the framework.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;


function test_meeg_ft_dataset()
% Verify conversion of FieldTrip-like structs to datasets and back
%
% For every dimord returned by get_ft_dimords the generated struct is
% converted with cosmo_meeg_dataset; feature dimensions and data are
% compared with the input, a feature-permuted copy is mapped back with
% cosmo_map2meeg, and a trialinfo field of the wrong size must not be
% stored. Finally, malformed input structs must raise an exception.
    all_dimords=get_ft_dimords();

    for idx=1:numel(all_dimords)
        [ft,expected_fdim,samples_key]=generate_ft_struct(all_dimords{idx});
        ds=cosmo_meeg_dataset(ft);
        assertEqual(ds.a.fdim,expected_fdim);

        nsamples=size(ds.samples,1);
        nfeatures=size(ds.samples,2);

        % the number of features must equal the product of the sizes of
        % the feature dimensions
        fdim_sizes=cellfun(@numel,expected_fdim.values);
        assertEqual(prod(fdim_sizes),nfeatures);

        % when there is more than one sample, the first data dimension
        % is left out of the comparison with the feature dimensions
        raw=ft.(samples_key);
        raw_size=size(raw);
        if nsamples>1
            feature_size=raw_size(2:end);
        else
            feature_size=raw_size;
        end
        assertEqual(feature_size',fdim_sizes);

        assertElementsAlmostEqual(raw(:),ds.samples(:));

        % permute the features, then map back to a FieldTrip struct
        ds_permuted=cosmo_slice(ds,randperm(nfeatures),2);
        ft_back=cosmo_map2meeg(ds_permuted);

        % fields that are dropped before comparing with the mapped struct
        if isfield(ft,'cfg')
            ft=rmfield(ft,'cfg');
        end

        has_avg_and_trial=isfield(ft,'avg') && isfield(ft,'trial');
        if has_avg_and_trial
            ft=rmfield(ft,'avg');
        end

        assertEqual(ft,ft_back);

        % trialinfo is present for proper input, but must be absent when
        % the input trialinfo has the wrong size
        assertTrue(isfield(ds_permuted.sa,'trialinfo'));
        ft.trialinfo=[1;2];
        ds_bad_trialinfo=cosmo_meeg_dataset(ft);
        assertFalse(isfield(ds_bad_trialinfo.sa,'trialinfo'));
    end

    % malformed structs are rejected
    bad_inputs={struct(),...
                struct('avg',1),...
                struct('avg',1,'dimord','rpt_foo')};
    for idx=1:numel(bad_inputs)
        bad_input=bad_inputs{idx};
        assertExceptionThrown(@()cosmo_meeg_dataset(bad_input),'');
    end

function test_meeg_ft_dataset_trials()
% Verify the 'trials' option of cosmo_meeg_dataset for FieldTrip structs
%
% Selecting trials while loading must give the same result as slicing
% the fully loaded dataset, both with key-value and with struct options;
% illegal 'trials' values must raise an exception.
    all_dimords=get_ft_dimords();

    for idx=1:numel(all_dimords)
        ft=generate_ft_struct(all_dimords{idx});
        ds_full=cosmo_meeg_dataset(ft);

        % pick two random trial indices
        ntrials=size(ds_full.samples,1);
        selection=ceil(rand(1,2)*ntrials);
        ds_expected=cosmo_slice(ds_full,selection);

        % options as key-value pair
        ds_subset=cosmo_meeg_dataset(ft,'trials',selection);
        assertEqual(ds_expected,ds_subset);

        % options as struct
        opt=cosmo_structjoin('trials',selection);
        ds_subset=cosmo_meeg_dataset(ft,opt);
        assertEqual(ds_expected,ds_subset);

        % values that are not valid trial indices
        bad_values={ntrials+1,0,struct,cell(1,0),'foo',true,1.5};
        for j=1:numel(bad_values)
            bad_value=bad_values{j};
            assertExceptionThrown(@()...
                    cosmo_meeg_dataset(ft,'trials',bad_value),'');
        end
    end



function test_synthetic_meeg_dataset()
% Round-trip synthetic datasets through cosmo_map2meeg / cosmo_meeg_dataset
%
% Every fourth combination of dataset type and size is tested. The last
% dataset generated in the loop is then used to check the 'targets'
% option.
    types={'timelock','timefreq','source'};
    sizes={'tiny','small','normal','big','huge'};
    combis=cosmo_cartprod({types,sizes});

    rows=1:4:size(combis,1);
    for row=rows
        ds=cosmo_synthetic_dataset('type',combis{row,1},...
                                        'size',combis{row,2});

        ds_back=cosmo_meeg_dataset(cosmo_map2meeg(ds));

        assertEqual(ds.samples,ds_back.samples);
        assertEqual(ds.fa,ds_back.fa);
        assertEqual(ds.a.meeg.samples_field,...
                        ds_back.a.meeg.samples_field);
    end

    % a scalar target is accepted and applied to all samples
    ds_with_targets=cosmo_meeg_dataset(ds,'targets',1);
    assertTrue(all(ds_with_targets.sa.targets==1));

    % a two-element targets vector must raise an exception
    assertExceptionThrown(@()cosmo_meeg_dataset(ds,'targets',[1 2]),'');



function test_meeg_eeglab_txt_io()
% Round-trip test for text (.txt) files read by cosmo_meeg_dataset
%
% A synthetic dataset is written to a tab-separated text file (header row
% with channel labels, then one row per time point and trial), read back
% with cosmo_meeg_dataset and compared with the original. Also covered:
% the 'trials' option, illegal 'trials' values, a file with bogus
% trailing data, and writing a text file with cosmo_map2meeg.
%
% Temporary files are created with tempname() (as in test_eeglab_io), so
% that the test does not write to the current directory and the two
% temporary file names cannot collide.
    ds=cosmo_synthetic_dataset('type','meeg');

    tmp_fn=[tempname() '.txt'];
    fid=fopen(tmp_fn,'w');
    assertTrue(fid>=0); % fopen returns -1 if the file cannot be opened
    file_remover=onCleanup(@()delete(tmp_fn));
    file_closer=onCleanup(@()fclose(fid));

    % header: empty first column followed by the channel labels
    chans=[{' '} ds.a.fdim.values{1}];
    fprintf(fid,'%s\t',chans{:});
    fprintf(fid,'\n');

    times=ds.a.fdim.values{2};
    ntime=numel(times);
    nsamples=size(ds.samples,1);

    % one row for each combination of trial and time point
    for k=1:nsamples
        for j=1:ntime
            data=ds.samples(k,ds.fa.time==j);
            fprintf(fid,'%.3f',times(j));
            fprintf(fid,'\t%.4f',data);
            fprintf(fid,'\n');
        end
    end

    % clearing the cleanup object closes the file, exactly once; no
    % separate fclose call is needed (or allowed) afterwards
    clear file_closer;

    ds2=cosmo_meeg_dataset(tmp_fn);

    % data was written with four decimals
    assertElementsAlmostEqual(ds.samples,ds2.samples,'absolute',1e-4);
    assertEqual(ds.a.fdim.values{1},ds2.a.fdim.values{1});
    % time values that are read back are a factor 1000 smaller than
    % the values that were written
    assertElementsAlmostEqual(ds.a.fdim.values{2},...
                                    1000*ds2.a.fdim.values{2});
    assertEqual(ds.a.fdim.labels,ds2.a.fdim.labels);
    assertEqual(ds.fa,ds2.fa);

    % test trials option
    trial_idx=ceil(rand(1,2)*nsamples);
    ds_trials=cosmo_meeg_dataset(tmp_fn,'trials',trial_idx);
    ds_expected_trials=cosmo_slice(ds,trial_idx);
    assertElementsAlmostEqual(ds_trials.samples,...
                                ds_expected_trials.samples,...
                                'absolute',1e-4);

    % test illegal options
    aet=@(varargin)assertExceptionThrown(@()...
                            cosmo_meeg_dataset(varargin{:}),'');
    illegal_args={nsamples+1,0,struct,{},'foo',true,1.5};
    for j=1:numel(illegal_args)
        arg=illegal_args{j};
        aet(tmp_fn,'trials',arg);
        aet(tmp_fn,cosmo_structjoin('trials',arg));
    end

    % add bogus data, expect exception
    fid=fopen(tmp_fn,'a');
    assertTrue(fid>=0);
    file_closer=onCleanup(@()fclose(fid));
    fprintf(fid,'.3');
    clear file_closer; % closes the file

    aet(tmp_fn);

    % write with cosmo_map2meeg, then read back
    tmp2_fn=[tempname() '.txt'];
    file_remover2=onCleanup(@()delete(tmp2_fn));
    cosmo_map2meeg(ds2,tmp2_fn);
    ds3=cosmo_meeg_dataset(tmp2_fn);
    assertEqual(ds2,ds3);


function test_meeg_ft_io()
% Round-trip a synthetic MEEG dataset through a .mat file
%
% The dataset is written with cosmo_map2meeg and read back with
% cosmo_meeg_dataset. Samples and the second feature dimension are
% compared with a tolerance; the remaining content must match exactly
% once the .a.meeg fields are cleared and .sa of the original is emptied.
    ds_orig=cosmo_synthetic_dataset('type','meeg');

    mat_fn=sprintf('_tmp_%06.0f.mat',rand()*1e5);
    file_remover=onCleanup(@()delete(mat_fn));

    cosmo_map2meeg(ds_orig,mat_fn);
    ds_read=cosmo_meeg_dataset(mat_fn);

    assertEqual(ds_orig.a.meeg.samples_field,...
                    ds_read.a.meeg.samples_field);

    % these fields are not compared any further
    ds_orig.a.meeg=[];
    ds_read.a.meeg=[];
    ds_orig.sa=struct();

    % deal with rounding errors in Octave
    assertElementsAlmostEqual(ds_orig.samples,ds_read.samples);
    assertElementsAlmostEqual(ds_orig.a.fdim.values{2},...
                                    ds_read.a.fdim.values{2});

    % a dataset struct as input is returned unchanged
    ds_again=cosmo_meeg_dataset(ds_read);
    assertEqual(ds_read,ds_again);

    % with the almost-equal parts copied over, everything else must be
    % exactly equal
    ds_read.samples=ds_orig.samples;
    ds_read.a.fdim.values{2}=ds_orig.a.fdim.values{2};
    assertEqual(ds_orig,ds_read);



function test_meeg_ft_io_exceptions()
% File names with a missing or unknown extension must raise an exception,
% both when reading (cosmo_meeg_dataset) and when writing (cosmo_map2meeg)
    ds=cosmo_synthetic_dataset('type','timefreq');

    bad_fns={'file_without_extension',...
             'file.with_unknown_extension'};
    nbad=numel(bad_fns);

    % reading
    for k=1:nbad
        bad_fn=bad_fns{k};
        assertExceptionThrown(@()cosmo_meeg_dataset(bad_fn),'');
    end

    % writing
    for k=1:nbad
        bad_fn=bad_fns{k};
        assertExceptionThrown(@()cosmo_map2meeg(ds,bad_fn),'');
    end

    % writing this time-frequency dataset to a .txt file is not supported
    assertExceptionThrown(@()cosmo_map2meeg(ds,'eeglab_timelock.txt'),'');


function dimords=get_ft_dimords()
% Return a 1x9 cell with the dimord strings used in these tests
%
% Each of the base dimords 'chan_time', 'chan_freq' and 'chan_freq_time'
% occurs three times in a row: without prefix, with the prefix 'rpt_',
% and with the prefix 'subj_'.
    prefixes={'','rpt_','subj_'};
    bases={'chan_time','chan_freq','chan_freq_time'};

    nprefix=numel(prefixes);
    nbase=numel(bases);

    dimords=cell(1,nprefix*nbase);
    pos=0;
    for b=1:nbase
        for p=1:nprefix
            pos=pos+1;
            dimords{pos}=[prefixes{p} bases{b}];
        end
    end

function [ft,fdim,data_label]=generate_ft_struct(dimord)
% Generate a FieldTrip-like struct for a given dimord
%
% Input:
%   dimord      string with dimension names separated by '_'; the names
%               'rpt', 'subj', 'chan', 'freq' and 'time' are recognized
%
% Outputs:
%   ft          struct with fields .dimord, .cfg, .trialinfo, the data in
%               field .(data_label), and depending on dimord the fields
%               .label, .freq and .time. If the data field is 'trial',
%               the field .avg contains the mean over the first dimension
%   fdim        struct with fields .values and .labels (column cells)
%               for the 'chan', 'freq' and 'time' dimensions in dimord
%   data_label  name of the field in ft that contains the data
%
% The k-th dimension in dimord has size k+2 (for up to four dimensions).
    seed=1;
    dim_sizes=[3 4 5 6];

    % values from which the first few are selected for each dimension
    all_chan={'MEG0113' 'MEG0112' 'MEG0111' 'MEG0122'...
                    'MEG0123' 'MEG0121' 'MEG0132'};
    all_freq=(2:2:24);
    all_time=(-1:.1:2);

    dim_names=cosmo_strsplit(dimord,'_');
    ndim=numel(dim_names);

    % at most three feature dimensions; trimmed after the loop
    max_nfdim=3;
    fdim=struct();
    fdim.values=cell(max_nfdim,1);
    fdim.labels=cell(max_nfdim,1);
    nfdim=0;

    ft=struct();
    ft.dimord=dimord;

    % defaults; later dimensions can override the data label
    data_label='avg';
    ntrials=1;

    for j=1:ndim
        name=dim_names{j};
        count=dim_sizes(j);

        is_feature_dim=false;
        switch name
            case 'rpt'
                data_label='trial';
                ntrials=count;

            case 'subj'
                data_label='individual';
                ntrials=count;

            case 'chan'
                ft.label=all_chan(1:count);
                fdim_value=ft.label;
                is_feature_dim=true;

            case 'freq'
                ft.freq=all_freq(1:count);
                data_label='powspctrm';
                fdim_value=ft.freq;
                is_feature_dim=true;

            case 'time'
                ft.time=all_time(1:count);
                fdim_value=ft.time;
                is_feature_dim=true;

        end

        if is_feature_dim
            nfdim=nfdim+1;
            fdim.values{nfdim}=fdim_value;
            fdim.labels{nfdim}=name;
        end
    end

    % only keep the feature dimensions that were set
    fdim.values=fdim.values(1:nfdim);
    fdim.labels=fdim.labels(1:nfdim);

    % data array with one dimension for each element in dimord
    data_size=dim_sizes(1:ndim);
    ft.(data_label)=cosmo_norminv(cosmo_rand(data_size,'seed',seed));
    ft.cfg=struct();

    % two columns: ascending and descending trial indices
    ft.trialinfo=[(1:ntrials);(ntrials:-1:1)]';

    if strcmp(data_label,'trial')
        ft.avg=mean(ft.(data_label),1);
    end


function test_eeglab_io()
% Test reading and writing of EEGLAB-like structs and files
%
% For each combination of ICA / non-ICA data, multiple / single trials,
% and datatype, a struct and the matching dataset are built with
% build_eeglab_dataset_struct; conversion is then tested in both
% directions, with struct input / output and through files.
    combis=cosmo_cartprod({{true,false},...
                           {true,false},...
                           {'timef','erp','itc'}});

    for row=1:size(combis,1)
        builder_args=combis(row,:);
        [s,ds,ext]=build_eeglab_dataset_struct(builder_args{:});

        % struct -> dataset
        ds_converted=cosmo_meeg_dataset(s);
        assertEqual(ds.samples,ds_converted.samples);
        assertEqual(ds,ds_converted);

        % file -> dataset, with the file written using save
        fn=[tempname() '.' ext];
        save(fn,'-mat','-struct','s');
        cleaner=onCleanup(@()delete(fn));

        ds_from_file=cosmo_meeg_dataset(fn);
        assertEqual(ds,ds_from_file);
        clear cleaner; % deletes the file

        % dataset -> struct
        s_back=cosmo_map2meeg(ds,['-' ext]);
        assertEqual(s,s_back);

        % dataset -> file, with the file read using load
        cosmo_map2meeg(ds,fn);
        cleaner=onCleanup(@()delete(fn));

        s_from_file=load(fn,'-mat');
        assertEqual(s_from_file,s);
        assertEqual(s,s_from_file);
        clear cleaner; % deletes the file
    end

function test_eeglab_io_trials()
% Test the 'trials' option when reading EEGLAB-like structs and files
%
% Loading a subset of trials must match slicing of the full dataset;
% illegal 'trials' values must raise an exception.
    combis=cosmo_cartprod({{true,false},...
                           {true,false},...
                           {'timef','erp','itc'}});

    for row=1:size(combis,1)
        builder_args=combis(row,:);
        [s,ds,ext]=build_eeglab_dataset_struct(builder_args{:});

        % pick two random trial indices
        ntrials=size(ds.samples,1);
        selection=ceil(rand(1,2)*ntrials);
        ds_expected=cosmo_slice(ds,selection);

        % struct input
        ds_from_struct=cosmo_meeg_dataset(s,'trials',selection);
        assertElementsAlmostEqual(ds_from_struct.samples,...
                                    ds_expected.samples,...
                                    'absolute',1e-4);

        % file input, with the file written using save
        fn=[tempname() '.' ext];
        save(fn,'-mat','-struct','s');
        cleaner=onCleanup(@()delete(fn));

        ds_from_file=cosmo_meeg_dataset(fn,'trials',selection);
        assertEqual(ds_from_file,ds_expected);

        % values that are not valid trial indices
        bad_values={ntrials+1,0,struct,{},'foo',true,1.5};
        for j=1:numel(bad_values)
            bad_value=bad_values{j};
            assertExceptionThrown(@()...
                    cosmo_meeg_dataset(s,'trials',bad_value),'');
        end

        clear cleaner; % deletes the file
    end

function test_eeglab_io_ersp()
% Test reading of EEGLAB-like 'ersp' data
%
% Such a struct holds both 'ersp' and 'erspbase' data, so the
% 'data_field' option is required to select which one to load. Writing
% such a dataset to a file must raise an exception.
    combis=cosmo_cartprod({{true,false},...
                           {false},...
                           {'ersp'},...
                           {false,true}});

    for row=1:size(combis,1)
        builder_args=combis(row,1:3);
        [s,ds_pair,ext]=build_eeglab_dataset_struct(builder_args{:});

        % first dataset in the pair is for 'ersp', the second one for
        % 'erspbase'
        use_baseline=combis{row,4};
        if use_baseline
            field_name='erspbase';
            ds_expected=ds_pair{2};
        else
            field_name='ersp';
            ds_expected=ds_pair{1};
        end
        load_args={'data_field',field_name};

        % a missing or invalid data_field is not allowed
        assertExceptionThrown(@()cosmo_meeg_dataset(s),'');
        assertExceptionThrown(@()...
                    cosmo_meeg_dataset(s,'data_field','foo'),'');
        assertExceptionThrown(@()...
                    cosmo_meeg_dataset(s,'data_field',false),'');

        % struct input with a valid data_field
        ds_from_struct=cosmo_meeg_dataset(s,load_args{:});
        assertEqual(ds_from_struct,ds_expected);

        % file input, with the file written using save
        fn=[tempname() '.' ext];
        save(fn,'-mat','-struct','s');
        cleaner=onCleanup(@()delete(fn));

        % here the options are passed as a single cell (not expanded)
        ds_from_file=cosmo_meeg_dataset(fn,load_args);

        % writing the file is not supported
        assertExceptionThrown(@()cosmo_map2meeg(ds_expected,fn),'');

        clear cleaner % deletes the file

        assertEqual(ds_from_file,ds_expected);
    end







function test_eeglab_io_exceptions()
% Invalid input for EEGLAB conversion must raise an exception
    % unknown datatype when reading
    s=build_eeglab_dataset_struct(true,true,'timef');
    s.datatype='foo';
    assertExceptionThrown(@()cosmo_meeg_dataset(s),'');

    % output argument is not a filename
    ds=cosmo_synthetic_dataset('type','timefreq');
    assertExceptionThrown(@()cosmo_map2meeg(ds,struct),'');

    % replace each feature dimension label in turn by every label other
    % than the proper one for that position
    proper_labels={'chan','freq','time'};
    candidate_labels={'chan','freq','time','foo'};

    for dim=1:numel(proper_labels)
        for j=1:numel(candidate_labels)
            candidate=candidate_labels{j};
            if strcmp(candidate,proper_labels{dim})
                % label is proper at this position, nothing to test
                continue;
            end

            ds_bad_fdim=ds;
            ds_bad_fdim.a.fdim.labels{dim}=candidate;
            assertExceptionThrown(@()...
                    cosmo_map2meeg(ds_bad_fdim,'-dattimef'),'');
        end
    end


function [s,ds,ext]=build_eeglab_dataset_struct(has_ica,has_trial,datatype,...
                        chan_dim,freq_dim,time_dim)
% Build a random EEGLAB-like struct together with the matching dataset
%
% Inputs:
%   has_ica     if true, data fields have the prefix 'comp' and the
%               extension starts with 'ica'; otherwise the prefix is
%               'chan', the extension starts with 'dat', and a
%               .chanlabels field with random labels is added
%   has_trial   if true, the number of trials is random (see randint);
%               otherwise there is a single trial
%   datatype    one of 'timef', 'erp', 'itc', 'ersp',
%               'ersp_baselinecorrected' or 'erspbase'
%   chan_dim    } optional number of channels, frequencies and time
%   freq_dim    } points; each one is random (see randint) if not
%   time_dim    } provided
%
% Outputs:
%   s           struct with one data field per channel (or component),
%               and the fields .datatype and .parameters; depending on
%               the inputs also .chanlabels, .freqs and .times
%   ds          dataset struct built from the same data; for datatype
%               'ersp' a cell with two datasets, the first built with
%               datatype 'ersp_baselinecorrected', the second with
%               datatype 'erspbase'
%   ext         file extension (without leading dot) for s
    if nargin<6
        time_dim=randint();
    end

    if nargin<5
        freq_dim=randint();
    end

    if nargin<4
        chan_dim=randint();
    end


    % trial dimension
    if has_trial
        trial_dim=randint();
    else
        trial_dim=1;
    end

    if strcmp(datatype,'ersp')
        % has baseline corrected data together with baseline data
        builder=@build_eeglab_dataset_struct;
        args={chan_dim,freq_dim,time_dim};
        [s1,ds1,ext]=builder(has_ica,has_trial,...
                                'ersp_baselinecorrected',args{:});
        [s2,ds2]=builder(has_ica,has_trial,...
                                'erspbase',args{:});

        % merge the two structs; fields in s1 take precedence
        keys=fieldnames(s1);
        for k=1:numel(keys)
            key=keys{k};
            s2.(key)=s1.(key);
        end

        s=s2;
        s.datatype=upper(datatype);

        % make sure parameters are the same
        ds1.a.meeg.parameters=s.parameters;
        ds2.a.meeg.parameters=s.parameters;

        if isfield(s,'chanlabels')
            chan_labels=s.chanlabels;
            ds1.a.fdim.values{1}=chan_labels;
            ds2.a.fdim.values{1}=chan_labels;
        end

        ds={ds1,ds2};
        % remove second part from extension
        ext=regexprep(ext,'_.*','');
        return;
    end

    % channel / component dimension
    if has_ica
        chan_prefix='comp';
        ext_prefix='ica';

        make_chan_prefix_func=@()chan_prefix;
    else
        chan_prefix='chan';
        ext_prefix='dat';

        make_chan_prefix_func=@randstr;
    end

    % suffix of the data field names; this also verifies that the
    % datatype is supported
    switch datatype
        case 'timef'
            chan_suffix='_timef';

        case 'erp'
            chan_suffix='';

        case 'ersp_baselinecorrected'
            chan_suffix='_ersp';

        case 'erspbase'
            chan_suffix='_erspbase';

        case 'itc'
            chan_suffix='_itc';

        otherwise
            assert(false);
    end

    make_chan_label=@(idx) sprintf('%s%d',make_chan_prefix_func(),idx);

    chan_label={chan_prefix};
    chan_value={arrayfun(make_chan_label,1:chan_dim,...
                            'UniformOutput',false)};

    % frequency dimension: of the supported datatypes (verified above),
    % only 'erp' has no frequency dimension
    has_freq=~strcmp(datatype,'erp');

    if has_freq
        freq_label={'freq'};
        freq_value={(1:freq_dim)*2};
        samples_type='timefreq';
    else
        freq_dim=[];
        freq_label={};
        freq_value={};
        samples_type='timelock';
    end

    ext_suffix=datatype;

    hastime=~strcmp(datatype,'erspbase');

    if hastime
        % include time dimension
        time_label={'time'};
        time_value={(1:time_dim)*.2-.1};
    else
        % no time dimension
        time_dim=[];
        time_label={};
        time_value={};
    end

    % data; dimensions that are absent were set to [] above and thus
    % do not occur in dim_sizes
    dim_sizes=[trial_dim,chan_dim,freq_dim,time_dim];
    dim_sizes_without_chan=[dim_sizes([1, 3:end]), 1];
    data_arr=randn(dim_sizes);

    % params
    parameters={randstr(), randstr()};

    % make dataset
    ds=cosmo_flatten(data_arr,...
                        [chan_label,freq_label,time_label],...
                        [chan_value,freq_value,time_value]);
    ds.sa=struct();
    ds.a.meeg.samples_field='trial';
    ds.a.meeg.samples_type=samples_type;
    ds.a.meeg.samples_label='rpt';
    ds.a.meeg.parameters=parameters;


    % make struct, with a separate data field for each channel
    s=struct();
    for k=1:chan_dim
        key=sprintf('%s%d%s',chan_prefix,k,chan_suffix);
        value=data_arr(:,k,:);
        value_rs=reshape(value,dim_sizes_without_chan);

        if has_freq
            % it seems that for freq data, single trial data is the last
            % dimension, whereas for erp data, single trial data is the first
            % dimension.
            value_rs=shiftdim(value_rs,1);
        end

        s.(key)=value_rs;
    end

    if ~has_ica
        s.chanlabels=chan_value{1};
        assert(iscellstr(s.chanlabels));
    end

    if has_freq
        s.freqs=freq_value{1};
    end

    if hastime
        s.times=time_value{1};
    end

    s.datatype=upper(ext_suffix);
    s.parameters=parameters;

    ext=[ext_prefix, ext_suffix];


function test_dimord_label()
% Test the dimord and data dimensions returned by cosmo_map2meeg
%
% For combinations of sample label, number of samples and datatype, the
% number of labels in .dimord and the number of dimensions of the data
% field must match the number of feature dimensions, plus one unless the
% data is considered an average (single sample without a samples_label).
    opt=struct();
    opt.samples_label={'','rpt','trial'};
    opt.nsamples={1,randint(),10};
    opt.datatype={'timefreq','timelock'};

    combis=cosmo_cartprod(opt);

    for idx=1:numel(combis)
        combi=combis{idx};
        nsamples=combi.nsamples;

        ds=cosmo_synthetic_dataset('type',combi.datatype,...
                                        'ntargets',1,...
                                        'nchunks',nsamples,...
                                    'size','big');
        assertEqual(size(ds.samples,1),nsamples);

        % set the samples label, or remove it when it is empty
        has_label=~isempty(combi.samples_label);
        if has_label
            ds.a.meeg.samples_label=combi.samples_label;
        else
            ds.a.meeg=rmfield(ds.a.meeg,'samples_label');
        end

        is_average=nsamples==1 && ~has_label;

        % name of the data field expected in the output
        if strcmp(combi.datatype,'timefreq')
            samples_field='powspctrm';
        elseif strcmp(combi.datatype,'timelock')
            if is_average
                samples_field='avg';
            else
                samples_field='trial';
            end
        else
            assert(false)
        end

        ft=cosmo_map2meeg(ds);
        dimord_labels=cosmo_strsplit(ft.dimord,'_');

        % non-average data has an extra dimension
        ndim_expected=numel(ds.a.fdim.labels);
        if ~is_average
            ndim_expected=ndim_expected+1;
        end

        assertEqual(numel(dimord_labels),ndim_expected);
        assertEqual(numel(size(ft.(samples_field))),ndim_expected);
    end


function test_meeg_source_dataset_pos_dim_inside_fields()
% mapping back and forth should be fine whether or not the
% 'dim', 'tri' and 'inside' fields are there or not
    ds_orig=cosmo_synthetic_dataset('type','source','size','huge');
    ft_orig=cosmo_map2meeg(ds_orig);

    % the three bits of combi indicate the presence of the 'dim', 'tri'
    % and 'inside' fields, respectively
    for combi=0:7
        has_dim=bitand(combi,4)>0;
        has_tri=bitand(combi,2)>0;
        has_inside=bitand(combi,1)>0;

        ft=ft_orig;

        if has_dim
            ft.dim=[2 2 2];
        end

        if has_tri
            % random indices in the range of the number of positions
            n_pos=size(ft.pos,1);
            ft.tri=ceil(rand(5,3)*n_pos);
        end

        if ~has_inside
            ft=rmfield(ft,'inside');
        end

        % verify fields are there
        ds=cosmo_meeg_dataset(ft);
        assertEqual(has_dim,isfield(ds.a.meeg,'dim'));
        assertEqual(has_tri,isfield(ds.a.meeg,'tri'));

        % map back
        ft_back=cosmo_map2meeg(ds);

        % cosmo_map2meeg always returns an inside field, so for
        % now we remove it if it was not present in the input
        if ~has_inside
            ft_back=rmfield(ft_back,'inside');
        end

        % ensure expected fields are preserved
        assertEqual(ft,ft_back);
    end


function x=randint()
% Return a random integer: a uniform random value between 5 and 15,
% rounded up
    lower_bound=5;
    range_width=10;
    x=ceil(rand()*range_width+lower_bound);

function x=randstr()
% Return a random string with 10 characters, with character codes based
% on uniform random values between 65 and 89
    nchar=10;
    first_code=65;
    code_range=24;
    x=char(rand(1,nchar)*code_range+first_code);