test meeg source

function test_suite = test_meeg_source()
    % tests for MEEG datasets in source space
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_meeg_dataset()
    tps = {'freq_pow', 'time_mom', 'rpt_trial_mom', 'rpt_trial_mom_1d', ...
           'time_mom_1d', ...
           'rpt_trial_pow', 'strange_ft_lcmv_pow'};
    for j = 1:numel(tps)
        is_ft_strange_lcmv_avg_pow = isequal(cosmo_strsplit(tps{j}, '_', 1), ...
                                             'strange');

        [ft, fdim, data_label] = generate_ft_source(tps{j});
        ds = cosmo_meeg_dataset(ft);

        % check fdim
        assertEqual(ds.a.fdim, fdim);

        key = data_label{1};
        sub_key = data_label{2};

        ft_first_sample_all = ft.(key).(sub_key);

        if is_ft_strange_lcmv_avg_pow
            ft_first_sample = ft_first_sample_all(:, ft.inside);
        else
            ft_first_sample = ft_first_sample_all(ft.inside, :);
        end

        switch sub_key
            case 'mom'
                % choose a random sensor
                inside_idxs = find(ft.inside);
                pos = ceil(rand() * numel(inside_idxs));
                inside_idx = inside_idxs(pos);
                assertEqual(ds.samples(1, ds.fa.pos == inside_idx), ...
                            ft_first_sample{pos}(:)');

            otherwise
                assertEqual(ds.samples(1, :), ft_first_sample(:)');
        end

        [ft_arr, ft_labels, ft_values] = cosmo_unflatten(ds, 2, ...
                                                         'matrix_labels', 'pos');

        if is_ft_strange_lcmv_avg_pow
            ft_arr = ft_arr';
        end
        assertEqual(ft_arr(ft_arr ~= 0), ds.samples(:));
        assertEqual(ft_labels, fdim.labels);
        assertEqual(ft_values, fdim.values);

        % select single element, and ensure it is the same in the
        % fieldtrip struct as in the dataset struct
        dim_sizes = cellfun(@(x)size(x, 2), fdim.values);
        ndim = numel(dim_sizes);
        rp = ceil(rand(1, ndim) .* dim_sizes(:)');
        [nsamples, nfeatures] = size(ds.samples);
        ds_msk = false(1, nfeatures);
        ft_idx = cell(1, 1 + ndim);
        ft_idx{1} = randperm(nsamples);
        for k = 1:ndim
            dim_label = fdim.labels{k};
            ds_msk = ds_msk | rp(k) ~= ds.fa.(dim_label);

            switch dim_label
                case 'mom'
                    ft_idx{k + 1} = rp(k);

                case 'pos'
                    ft_values = ft.(dim_label);
                    ft_idx{k + 1} = find(all(bsxfun(@eq, ...
                                                    ft_values(rp(k), :), ft_values), 2));

                otherwise
                    ft_values = ft.(dim_label);
                    ft_idx{k + 1} = find(ft_values(rp(k)) == ft_values);
            end
        end

        ds_sel = cosmo_slice(cosmo_slice(ds, ~ds_msk, 2), ft_idx{1});

        if is_ft_strange_lcmv_avg_pow
            ft_idx = ft_idx(2:end);
        end

        ft_sel = ft_arr(ft_idx{:});

        if ft_sel == 0
            assertTrue(isempty(ds_sel.samples));
        else
            assertEqual(ds_sel.samples, ft_sel);
        end

        % re-order features
        nfeatures = size(ds.samples, 2);
        ds2 = cosmo_slice(ds, randperm(nfeatures), 2);

        ft2 = cosmo_map2meeg(ds2);

        if is_ft_strange_lcmv_avg_pow
            ft2.avg.pow = ft2.avg.pow';
            ft2.time = ft.time;
        end

        assertEqual(ft, ft2);

        if j == 1
            % test compatibility with old fieldtrip

            ft2.inside = find(ft2.inside);
            ds3 = cosmo_meeg_dataset(ft2);
            assertEqual(ds, ds3);

            ft2.inside = struct();
            assertExceptionThrown(@()cosmo_meeg_dataset(ft2), '');
        end
    end

function test_meeg_fmri_dataset()
    ds = cosmo_synthetic_dataset('type', 'source');
    ds_fmri = cosmo_fmri_dataset(ds);
    ft = cosmo_map2meeg(ds);
    ds_ft_fmri = cosmo_fmri_dataset(ft);

    ds_vol = cosmo_vol_grid_convert(ds, 'tovol');
    assertEqual(ds_vol, ds_fmri);

    assertTrue(isempty(fieldnames(ds_ft_fmri.sa)));
    ds_vol = rmfield(ds_vol, 'sa');
    ds_ft_fmri = rmfield(ds_ft_fmri, 'sa');
    assertEqual(ds_vol, ds_ft_fmri);

function test_irregular_source_grid()
    ft = generate_ft_source('freq_pow');
    ft.pos = ft.pos + randn(size(ft.pos));

    ds = cosmo_meeg_dataset(ft);
    ft2 = cosmo_map2meeg(ds);

    assertEqual(ft, ft2);

function [ft, fdim, data_label] = generate_ft_source(tp)
    ft = struct();
    dim_pos_range = {-3:3, -4:4, -5:5};
    nsamples = 2;
    freq = [3 5 7 9];
    time = [-1 0 1 2];
    mom_labels_3d = {'x', 'y', 'z'};
    mom_labels_1d = {'xyz'};

    ft.dim = cellfun(@numel, dim_pos_range);
    ft.pos = cosmo_cartprod(dim_pos_range);
    ft.inside = sum(ft.pos.^2, 2) < 30;

    fdim = struct();

    switch tp
        case 'freq_pow'
            ft.freq = freq;
            ft.method = 'average';
            ft.avg = generate_data(ft.inside, numel(freq), 1, 'pow');
            fdim.labels = {'pos'; 'freq'};
            fdim.values = {ft.pos'; freq(:)'};
            data_label = {'avg', 'pow'};

        case 'time_mom'
            ft.time = time;
            ft.method = 'average';
            ft.avg = generate_data(ft.inside, numel(time), 1, 'mom3d');
            fdim.labels = {'pos'; 'mom'; 'time'};
            fdim.values = {ft.pos'; mom_labels_3d; time(:)'};
            data_label = {'avg', 'mom'};

        case 'time_mom_1d'
            ft.time = time;
            ft.method = 'average';
            ft.avg = generate_data(ft.inside, numel(time), 1, 'mom1d');
            fdim.labels = {'pos'; 'mom'; 'time'};
            fdim.values = {ft.pos'; mom_labels_1d; time(:)'};
            data_label = {'avg', 'mom'};

        case 'rpt_trial_pow'
            ft.time = time;
            ft.method = 'rawtrial';
            ft.trial = generate_data(ft.inside, numel(time), nsamples, 'pow');
            fdim.labels = {'pos'; 'time'};
            fdim.values = {ft.pos'; time(:)'};
            data_label = {'trial', 'pow'};

        case 'rpt_trial_mom'
            ft.time = time;
            ft.method = 'rawtrial';
            ft.trial = generate_data(ft.inside, numel(time), nsamples, 'mom3d');
            fdim.labels = {'pos'; 'mom'; 'time'};
            fdim.values = {ft.pos'; mom_labels_3d; time(:)'};
            data_label = {'trial', 'mom'};

        case 'rpt_trial_mom_1d'
            ft.time = time;
            ft.method = 'rawtrial';
            ft.trial = generate_data(ft.inside, numel(time), nsamples, 'mom1d');
            fdim.labels = {'pos'; 'mom'; 'time'};
            fdim.values = {ft.pos'; mom_labels_1d; time(:)'};
            data_label = {'trial', 'mom'};

        case 'strange_ft_lcmv_pow'
            % FT with LCMV source puts .avg data in a 1xNSOURCE matrix,
            % even with NTIME>1 timepoints. In this case, the output
            % dataset structure has the time points averaged.
            ft.time = time;
            ft.method = 'average';
            ft.avg = generate_data(ft.inside, 1, 1, 'pow');
            ft.avg.pow = ft.avg.pow';
            fdim.labels = {'pos'; 'time'};
            fdim.values = {ft.pos'; mean(time)};
            data_label = {'avg', 'pow'};

        otherwise
            error('unsupported type %s', tp);

    end

function d = generate_data(inside, nfeatures, nsamples, fld)
    is_single_trial = nargin < 3;

    if is_single_trial
        nsamples = 1;
    end

    switch fld
        case 'mom3d'
            d = generate_mom(inside, nfeatures, nsamples, 3);
        case 'mom1d'
            d = generate_mom(inside, nfeatures, nsamples, 1);
        case 'pow'
            d = generate_pow(inside, nfeatures, nsamples);
        otherwise
            error('not supported: %s', fld);
    end

function all_trials = generate_pow(inside, nfeatures, nsamples)
    nf = numel(inside);
    ni = sum(inside);

    all_data = NaN(nf, nfeatures);
    trial_data = cosmo_rand(ni, nfeatures);

    all_trials_cell = cell(nsamples, 1);
    for j = 1:nsamples
        data = all_data;
        data(inside, :) = trial_data + j;
        all_trials_cell{j}.pow = data;
    end

    all_trials = cat(1, all_trials_cell{:})';

function all_trials = generate_mom(inside, nfeatures, nsamples, ndim)
    nf = numel(inside);
    i = find(inside);
    ni = numel(i);

    data = cosmo_rand(ndim, nfeatures, ni);

    all_trials_cell = cell(1, nsamples);
    for j = 1:nsamples
        one_trial = cell(nf, 1);
        for k = 1:ni
            one_trial{i(k)} = data(:, :, k) + j;
        end
        all_trials_cell{j} = one_trial;
    end

    all_trials = struct('mom', all_trials_cell);