cosmo map2meeg skl

function hdr = cosmo_map2meeg(ds, fn)
    % converts an MEEG dataset to a FieldTrip or EEGLAB structure or file
    %
    % hdr=cosmo_map2meeg(ds[, fn])
    %
    % Inputs:
    %    ds               dataset struct with field .samples with MEEG data
    %    fn               output filename or extension (default: '-mat').
    %                     If a filename, the following extensions are
    %                     supported:
    %                       .mat :        FieldTrip time-locked or
    %                                     time-frequency data at either the
    %                                     sensor or source level.
    %                       .txt :        exported EEGLAB with time-locked
    %                                     data.
    %                       .daterp       time-locked               }
    %                       .icaerp       ICA time-locked           } EEGLAB
    %                       .dattimef     time-freq                 }
    %                       .icatimef     ICA time-freq             }
    %                       .datitc       inter-trial coherence     }
    %                       .icaitc       ICA inter-trial coherence }
    %                     To get the output in hdr without writing a file,
    %                     use one of the extensions above with the dot
    %                     ('.') replaced by a hyphen ('-'), for example
    %                     '-dattimef' for time-freq data.
    %
    % Returns:
    %    hdr        FieldTrip or EEGLAB struct with the MEEG data
    %
    % Notes:
    %    - a typical use case is to map the dataset to a FieldTrip struct
    %      with this function, then use FieldTrip to visualize the data
    %    - there is currently no support for writing EEGLAB 'ersp' data.
    %
    % Examples:
    %     % convert a dataset struct to a FieldTrip struct
    %     ft=cosmo_map2meeg(ds);
    %
    %     % store a dataset in FieldTrip file
    %     cosmo_map2meeg(ds,'fieldtrip_data.mat');
    %
    %     % store a timeseries dataset in an EEGLAB text file
    %     cosmo_map2meeg(ds,'eeglab_data.txt');
    %
    %     % convert a dataset structure to a FieldTrip structure
    %     ft=cosmo_map2meeg(ds,'-mat');
    %
    %     % convert a time-lock dataset to an EEGLAB structure
    %     eeglab_daterp=cosmo_map2meeg(ds,'-daterp');
    %
    %     % write EEGLAB time-frequency data
    %     cosmo_map2meeg(ds,'timefreq.dattimef');
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    % without a second argument, build a FieldTrip struct and write no file
    if nargin < 2
        fn = '-mat';
    end

    cosmo_check_dataset(ds, 'meeg');

    % look up the builder and writer that belong to the requested format
    [ext, img_format, write_to_file] = find_img_format(fn);

    build_func = img_format.builder;
    hdr = build_func(ds, ext);

    % struct-only output requested: nothing to store
    if ~write_to_file
        return
    end

    write_func = img_format.writer;
    write_func(fn, hdr);

function [ext, img_format, write_to_file] = find_img_format(fn)
    % Determines the output format from a filename or '-ext' specifier
    %
    % Inputs:
    %   fn              non-empty string; either a filename with an
    %                   extension, or a hyphen followed by an extension
    %
    % Returns:
    %   ext             extension without the leading '.' or '-'
    %   img_format      struct with fields .exts, .builder and .writer
    %   write_to_file   true unless fn starts with a hyphen
    %
    % An error is raised if fn is not a non-empty string, or if no
    % supported format lists the extension.
    is_valid_name = ischar(fn) && ~isempty(fn);
    if ~is_valid_name
        error('filename must be non-empty string');
    end

    % a leading hyphen means: only return the struct, do not write
    is_struct_only = fn(1) == '-';
    write_to_file = ~is_struct_only;

    if is_struct_only
        ext = fn(2:end);
    else
        ext = get_filename_extension(fn);
    end

    all_formats = get_all_supported_img_formats();
    keys = fieldnames(all_formats);
    nkeys = numel(keys);

    % mark every format that lists this extension
    match_msk = false(nkeys, 1);
    for k = 1:nkeys
        candidate = all_formats.(keys{k});
        match_msk(k) = cosmo_match({ext}, candidate.exts);
    end

    n_match = sum(match_msk);
    assert(n_match <= 1); % cannot have multiple matches

    if n_match == 0
        error('Image format not found for extension ''%s''', ext);
    end

    img_format = all_formats.(keys{match_msk});

function ext = get_filename_extension(fn)
    % Returns the part of fn after the last '.'
    %
    % An error is raised if splitting fn by '.' gives fewer than two parts.
    parts = cosmo_strsplit(fn, '.');
    nparts = numel(parts);

    if nparts <= 1
        error('Filename needs extension');
    end

    ext = parts{nparts};

function all_formats = get_all_supported_img_formats()
    % Returns the table of supported output formats
    %
    % Returns:
    %   all_formats   struct with one field per format; each value is a
    %                 struct with fields:
    %                   .exts      cell with supported extensions
    %                   .builder   function handle, called as
    %                              builder(ds, ext), that builds the struct
    %                   .writer    function handle, called as
    %                              writer(fn, hdr), that stores the struct
    eeglab_struct_exts = {'daterp', 'icaerp', ...
                          'dattimef', 'icatimef', ...
                          'datitc', 'icaitc'};

    all_formats = struct();

    % EEGLAB text: built as a FieldTrip struct, written as text
    all_formats.eeglab_txt = struct('exts', {{'txt'}}, ...
                                    'builder', @build_ft, ...
                                    'writer', @write_eeglab_txt);

    % EEGLAB matlab
    all_formats.eeglab = struct('exts', {eeglab_struct_exts}, ...
                                'builder', @build_eeglab, ...
                                'writer', @write_struct_as_mat);

    % FieldTrip matlab
    all_formats.ft = struct('exts', {{'mat'}}, ...
                            'builder', @build_ft, ...
                            'writer', @write_struct_as_mat);

function write_struct_as_mat(fn, hdr)
    % Stores the fields of the struct hdr in the file fn in .mat format
    %
    % Inputs:
    %   fn     output filename
    %   hdr    struct; through the '-struct' option each field is stored
    %          as a separate variable
    %
    % Note: save looks up the variable by name, so the string 'hdr' below
    % must match the name of the second input argument.
    save(fn, '-mat', '-struct', 'hdr');

function tf = choose_equal_or_exception(value, if_true, if_false, desc)
    % Maps value to true or false, or raises an error for any other value
    %
    % Inputs:
    %   value      value to test
    %   if_true    value for which true is returned
    %   if_false   value for which false is returned
    %   desc       description of value, used in the error message
    %
    % Returns:
    %   tf         true if value equals if_true, false if it equals
    %              if_false (if_true is tested first)
    tf = isequal(value, if_true);

    if ~tf && ~isequal(value, if_false)
        error('value for %s is not supported', desc);
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % EEGLAB text
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function write_eeglab_txt(fn, hdr)
    % Writes time-locked data to a tab-separated text file
    %
    % Inputs:
    %   fn     output filename
    %   hdr    struct for which is_ft_timelock returns true, with fields
    %            .label   channel labels
    %            .time    time points in seconds
    %            .trial   ntrial x nchan x ntime data array
    %
    % The file starts with a header line with the channel labels. Each
    % following line has the time in milliseconds followed by the value
    % for each channel; rows are ordered by trial, and by time within each
    % trial.
    %
    % An error is raised if hdr is not time-locked data or if the file
    % cannot be opened for writing.
    if ~is_ft_timelock(hdr)
        error('Only time-lock data is supported for EEGlab data');
    end

    % prepare header
    header = [cosmo_strjoin([{' '}, hdr.label(:)', {''}], '\t') '\n'];

    % prepare body
    data = hdr.trial;
    [ntrial, nchan, ntime] = size(data);
    arr = zeros(ntrial * ntime, nchan + 1);

    % set time dimension - and convert seconds to milliseconds
    arr(:, 1) = repmat(hdr.time(:) * 1000, ntrial, 1);

    % move time to the first dimension, so that after reshaping the rows
    % run over time first and then over trials (matching the time column)
    arr(:, 2:end) = reshape(shiftdim(data, 2), ntime * ntrial, nchan);

    % prepare pattern to write data in array
    arr_pat = [cosmo_strjoin(repmat({'%.4f'}, 1, nchan + 1), '\t') '\n'];

    % open the file; fopen signals failure through a negative identifier
    fid = fopen(fn, 'w');
    if fid < 0
        error('Unable to open file ''%s'' for writing', fn);
    end

    % close the file when this function exits, also if writing fails
    cleaner = onCleanup(@()fclose(fid));

    % write data
    fprintf(fid, header);
    fprintf(fid, arr_pat, arr'); % transpose because order is row then column

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % EEGLAB
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function s = build_eeglab(ds, ext)
    % Builds an EEGLAB-style struct from a dataset
    %
    % Inputs:
    %   ds     dataset struct with .samples and .a.fdim; the first feature
    %          dimension label must be 'comp' or 'chan', the remaining
    %          labels must be {'freq';'time'} or {'time'}
    %   ext    output extension; when a 'freq' dimension is present it must
    %          contain exactly one of 'timef' or 'itc'
    %
    % Returns:
    %   s      struct with fields .datatype, .times, one data field for
    %          each channel or component, .freqs (if a 'freq' dimension is
    %          present), .chanlabels (for 'chan' data only) and .parameters
    %          (if present in ds.a.meeg)
    fdim = ds.a.fdim;
    fdim_labels = fdim.labels(:);
    fdim_values = fdim.values(:);

    % an error is raised for any other combination of dimension labels
    has_ica = choose_equal_or_exception(fdim_labels{1}, 'comp', 'chan', ...
                                        'fdim channel label');
    has_freq = choose_equal_or_exception(fdim_labels(2:end), ...
                                         {'freq'; 'time'}, {'time'}, ...
                                         'fdim dimension labels');
    if has_ica
        chan_prefix = 'comp';
    else
        chan_prefix = 'chan';
    end

    % prepare output
    s = struct();

    % set frequency, if present
    if has_freq
        s.freqs = fdim_values{2};
        freq_sz = numel(s.freqs);

        % the datatype is derived from the extension, which must contain
        % exactly one of the candidates
        datatype_candidates = {'timef', 'itc'};
        msk = cellfun(@(x)contains_string(ext, x), datatype_candidates);
        idx = find(msk);

        assert(numel(idx) == 1, 'this should not happen'); % weird data
        datatype = datatype_candidates{idx};

        chan_suffix = sprintf('_%s', datatype);
    else
        % empty size, so that it adds no dimension to each_chan_sz below
        freq_sz = [];
        datatype = 'erp';
        chan_suffix = '';
    end

    % set datatype
    s.datatype = upper(datatype);

    % deal with feature dimensions
    nsamples = size(ds.samples, 1);
    ntime = numel(fdim_values{end});
    each_chan_sz = [nsamples, freq_sz, ntime]; % with or without freq

    chan_names = fdim_values{1};
    nchan = numel(chan_names);

    % unflatten the array; the second dimension is indexed by channel
    % (or component), as verified by the assertion
    arr = cosmo_unflatten(ds, 2);
    assert(nchan == size(arr, 2));

    for k = 1:nchan
        chan_arr = reshape(arr(:, k, :), each_chan_sz);
        % field name is prefix, channel index and suffix; for example
        % 'chan1' for erp data or 'comp2_timef' for ICA time-freq data
        key = sprintf('%s%d%s', chan_prefix, k, chan_suffix);

        % it seems that for freq data, single trial data is the last
        % dimension, whereas for erp data, single trial data is the first
        % dimension.
        if has_freq
            % move the sample (trial) dimension from first to last
            chan_arr = shiftdim(chan_arr, 1);
            % note: no shift for erp data, as the sample (trial) dimension
            % is already the first dimension
        end
        s.(key) = chan_arr;
    end

    % channel labels are only stored for non-ICA data
    if ~has_ica
        s.chanlabels = chan_names;
    end

    s.times = fdim_values{end};
    s = set_parameters_if_present(s, ds);

function tf = contains_string(haystack, needle)
    % Returns true if strfind finds needle in haystack, false otherwise
    match_positions = strfind(haystack, needle);
    tf = numel(match_positions) > 0;

function s = set_parameters_if_present(s, ds)
    % Copies ds.a.meeg.parameters to s.parameters, if that field exists;
    % otherwise s is returned unchanged
    if ~cosmo_isfield(ds, 'a.meeg.parameters')
        return
    end

    s.parameters = ds.a.meeg.parameters;

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % FieldTrip
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function tf = is_ft_timelock(ft)
    % Returns true if ft is a struct with a .dimord field that is one of
    % the time-locked dimension orders
    timelock_dimords = {'rpt_chan_time', ...
                        'subj_chan_time', ...
                        'chan_time'};

    has_dimord = isstruct(ft) && isfield(ft, 'dimord');

    % short-circuit ensures .dimord is only accessed if it exists
    tf = has_dimord && cosmo_match({ft.dimord}, timelock_dimords);

function samples_field = ft_detect_samples_field(ds, is_without_samples_dim)
    % Determines the name of the FieldTrip field that holds the data
    %
    % Inputs:
    %   ds                       dataset struct with .a.fdim.labels
    %   is_without_samples_dim   true if the output has no sample
    %                            dimension
    %
    % Returns:
    %   samples_field            for source data '<main>.<sub>', with
    %                            <main> 'avg' or 'trial' and <sub> 'mom' or
    %                            'pow'. Otherwise, for data with 'chan' and
    %                            'time' dimensions: 'powspctrm' if a 'freq'
    %                            dimension is present, else 'avg' without a
    %                            sample dimension, else 'trial' if
    %                            ds.a.meeg.samples_field is not set.
    %                            In all other cases the value of
    %                            ds.a.meeg.samples_field.
    labels = ds.a.fdim.labels;

    if is_ds_source_struct(ds)
        if is_without_samples_dim
            main_field = 'avg';
        else
            main_field = 'trial';
        end

        if cosmo_match({'mom'}, labels)
            sub_field = 'mom';
        else
            sub_field = 'pow';
        end

        samples_field = [main_field '.' sub_field];
        return
    end

    % number of dimensions with a particular label
    count_label = @(name) sum(cosmo_match(labels, {name}));

    has_chan_and_time = count_label('chan') >= 1 && ...
                            count_label('time') >= 1;

    if has_chan_and_time
        if count_label('freq') >= 1
            % time-freq data
            samples_field = 'powspctrm';
            return
        elseif is_without_samples_dim
            % time-locked, single sample
            samples_field = 'avg';
            return
        elseif ~cosmo_isfield(ds, 'a.meeg.samples_field')
            % time-locked, multiple trials
            samples_field = 'trial';
            return
        end
    end

    % nothing matched: use the field name stored in the dataset
    samples_field = ds.a.meeg.samples_field;

function tf = is_ds_source_struct(ds)
    % Returns true if ds has a feature attribute 'pos' and 'pos' is one of
    % the feature dimension labels
    has_pos_attribute = isfield(ds, 'fa') && isfield(ds.fa, 'pos');

    % short-circuit ensures .a.fdim.labels is only accessed if it exists
    tf = has_pos_attribute && ...
                cosmo_isfield(ds, 'a.fdim.labels') && ...
                cosmo_match({'pos'}, ds.a.fdim.labels);

function [ft, samples_label, dim_labels] = get_ft_samples(ds)
    % Builds a FieldTrip struct with the data from a dataset
    %
    % Inputs:
    %   ds              dataset struct
    %
    % Returns:
    %   ft              struct with the data stored in the field found by
    %                   ft_detect_samples_field; source data also gets the
    %                   fields set by get_ft_source_samples_from_array
    %   samples_label   cell with the label of the sample dimension:
    %                   {ds.a.meeg.samples_label} if that field is set,
    %                   otherwise empty for a single sample and {'rpt'}
    %                   for multiple samples
    %   dim_labels      feature dimension labels, as returned by
    %                   cosmo_unflatten
    %
    % An error is raised if the detected samples field has a subfield while
    % the dataset is not in source space (or the other way around), or if
    % it has more than one subfield.
    [arr, dim_labels] = cosmo_unflatten(ds, [], 'set_missing_to', NaN, ...
                                        'matrix_labels', {'pos'});

    if cosmo_isfield(ds, 'a.meeg.samples_label')
        samples_label = {ds.a.meeg.samples_label};
    else
        if size(ds.samples, 1) == 1
            samples_label = cell(0);
        else
            samples_label = {'rpt'};
        end
    end

    is_without_samples_dim = isempty(samples_label);

    % store the data
    samples_field = ft_detect_samples_field(ds, is_without_samples_dim);

    % a field such as 'trial.pow' consists of a main field and a subfield
    samples_field_keys = cosmo_strsplit(samples_field, '.');
    nsubfields = numel(samples_field_keys) - 1;

    % subfields are used for source data, and only for source data
    if xor(nsubfields > 0, is_ds_source_struct(ds))
        error(['Found sample field %s, which is incompatible '...
               'with the dataset being in source space or not. '...
               'This is not supported'], samples_field);
    end

    switch nsubfields
        case 0
            % non-source data
            ft = get_ft_sensor_samples_from_array(arr, samples_label, ...
                                                  samples_field_keys{1});

        case 1
            % source data
            ft = get_ft_source_samples_from_array(ds, arr, ...
                                                  samples_field_keys{1}, ...
                                                  samples_field_keys{2});

        otherwise
            % the format has a single conversion, matching the single
            % argument
            error(['Found sample field %s with more than one '...
                   'subfield'], samples_field);
    end

function ft = get_ft_sensor_samples_from_array(arr, samples_label, key)
    % Stores a data array in a new struct under the field name key
    %
    % Inputs:
    %   arr             data array with the sample dimension first
    %   samples_label   if empty, the first dimension of arr is removed
    %   key             name of the field in which the data is stored
    %
    % Returns:
    %   ft              struct with the single field named key
    arr_ft = arr;

    if isempty(samples_label)
        % drop the first dimension; the trailing 1 ensures the size vector
        % has at least two elements
        full_size = size(arr);
        arr_ft = reshape(arr, [full_size(2:end) 1]);
    end

    ft = struct();
    ft.(key) = arr_ft;

function ft = get_ft_source_samples_from_array(ds, arr, key, sub_key)
    % Builds a FieldTrip source struct with the data from an array
    %
    % Inputs:
    %   ds        source dataset struct
    %   arr       data array with the sample dimension first
    %   key       main field name, 'avg' or 'trial'
    %   sub_key   sub field name, 'pow' or 'mom'
    %
    % Returns:
    %   ft        struct with the fields set by init_ft_source_fields,
    %             with ft.(key) a 1 x nsamples struct array with field
    %             sub_key, and with ft.method set to 'average' (for 'avg')
    %             or 'rawtrial' (for 'trial')
    ft = init_ft_source_fields(ds);

    % select how the data of a single sample is converted
    if strcmp(sub_key, 'pow')
        converter = @convert_ft_source_vector2array;
    elseif strcmp(sub_key, 'mom')
        converter = @convert_ft_source_vector2cell;
    else
        error('unsupported key %s', sub_key);
    end

    full_size = size(arr);
    nsamples = full_size(1);
    feature_size = full_size(2:end);

    % one row per sample
    sample_rows = reshape(arr, nsamples, []);

    converted = cell(1, nsamples);
    for j = 1:nsamples
        converted{j} = converter(sample_rows(j, :), feature_size, ...
                                 ft.inside);
    end

    % a cell as value gives a struct array with one element per sample
    ft.(key) = struct(sub_key, converted);

    method_for_key = struct('avg', 'average', ...
                            'trial', 'rawtrial');
    ft.method = method_for_key.(key);

function arr = convert_ft_source_vector2array(arr_vec, remainder_size, ...
                                              is_inside)
    % Reshapes the data of one sample and sets outside elements to NaN
    %
    % Inputs:
    %   arr_vec          vector with data
    %   remainder_size   size of the output array
    %   is_inside        boolean mask; elements where it is false are set
    %                    to NaN
    %
    % Returns:
    %   arr              array of size [remainder_size 1]
    target_size = [remainder_size 1];
    arr = reshape(arr_vec, target_size);

    outside_msk = ~is_inside;
    arr(outside_msk) = NaN;

function arr_cell = convert_ft_source_vector2cell(arr_vec, remainder_size, ...
                                                  is_inside)
    % Splits the data of one sample into one cell element per row
    %
    % Inputs:
    %   arr_vec          vector with data
    %   remainder_size   size of the data; the first element is the number
    %                    of output cells
    %   is_inside        boolean mask; cells for which it is false are
    %                    left empty
    %
    % Returns:
    %   arr_cell         remainder_size(1) x 1 cell; each non-empty element
    %                    has size [remainder_size(2:end) 1 1]
    nsensors = remainder_size(1);
    arr_cell = cell(nsensors, 1);

    % the trailing ones ensure the size vector has at least two elements
    each_cell_size = [remainder_size(2:end) 1 1];

    % one row for each output cell
    row_matrix = reshape(arr_vec, nsensors, []);

    for k = 1:nsensors
        if ~is_inside(k)
            continue
        end

        arr_cell{k} = reshape(row_matrix(k, :), each_cell_size);
    end

function ft = ds_copy_fields_with_matching_sample_size(ft, ds, keys)
    % Copies sample attributes with one row per sample from ds to ft
    %
    % Inputs:
    %   ft      struct to which the fields are added
    %   ds      dataset struct with .samples and optionally .sa
    %   keys    cell with names of fields in ds.sa to consider
    %
    % Returns:
    %   ft      input struct, with ft.(key) set to ds.sa.(key) for each
    %           key that is present in ds.sa and has as many rows as
    %           ds.samples
    if ~isfield(ds, 'sa')
        return
    end

    sa = ds.sa;
    nsamples = size(ds.samples, 1);

    for k = 1:numel(keys)
        key = keys{k};

        % short-circuit ensures sa.(key) is only accessed if it exists
        if isfield(sa, key) && size(sa.(key), 1) == nsamples
            ft.(key) = sa.(key);
        end
    end

function ft = init_ft_source_fields(ds)
    % Initializes a FieldTrip source struct with the fields .inside and,
    % if present in ds.a.meeg, .dim and .tri
    %
    % Inputs:
    %   ds     dataset struct for which is_ds_source_struct returns true
    %
    % Returns:
    %   ft     struct with field .inside (boolean column vector) and
    %          optionally .dim and .tri, copied from ds.a.meeg
    ft = struct();
    % for MEEG source data, set the .inside field
    assert(is_ds_source_struct(ds));
    % the first output is not used; pos_index is presumably the index of
    % the 'pos' dimension among the feature dimensions - TODO confirm
    [dim, pos_index] = cosmo_dim_find(ds, 'pos', true);

    % set the inside field
    % use a single sample with all values set to 1, so that after
    % unflattening an element can only be non-zero if a feature is present
    % (assumes cosmo_unflatten sets missing features to zero by default -
    % TODO confirm)
    inside_ds = cosmo_slice(ds, 1, 1);
    inside_ds.samples(:) = 1;

    inside_arr = cosmo_unflatten(inside_ds, 2, 'matrix_labels', {'pos'});

    % shift by pos_index dimensions; the first dimension of inside_arr is
    % presumably the sample dimension, so that this moves the 'pos'
    % dimension to the front - TODO confirm
    inside_arr_pos_first = shiftdim(inside_arr, pos_index);
    n = size(inside_arr_pos_first, 1);
    inside_matrix_pos_first = reshape(inside_arr_pos_first, n, []);
    % a row is inside if it has a non-zero value anywhere along the
    % other dimensions
    ft.inside = any(inside_matrix_pos_first, 2);

    if cosmo_isfield(ds, 'a.meeg.dim')
        ft.dim = ds.a.meeg.dim;
    end

    if cosmo_isfield(ds, 'a.meeg.tri')
        ft.tri = ds.a.meeg.tri;
    end

function ft = build_ft(ds, unused)
    % Builds a FieldTrip struct from a dataset
    %
    % Inputs:
    %   ds        dataset struct
    %   unused    ignored; present so that all builders can be called as
    %             builder(ds, ext)
    %
    % Returns:
    %   ft        struct with the data, a field for each feature dimension
    %             other than 'mom' (with 'chan' stored as .label), .dimord
    %             for non-source data, and those of .rpt, .trialinfo and
    %             .cumtapcnt that are present in ds.sa with one row per
    %             sample
    %
    % If FieldTrip is available and does not recognize the type of the
    % output, a warning is shown.

    % get fieldtrip-specific fields from header
    [ft, samples_label, dim_labels] = get_ft_samples(ds);
    ndim = numel(dim_labels);

    % dimension labels are joined by underscores in dimord, so underscores
    % in the labels themselves are replaced by dashes
    dashed_labels = cell(ndim, 1);
    for k = 1:ndim
        dashed_labels{k} = strrep(dim_labels{k}, '_', '-');
    end
    dimord_labels = [samples_label; dashed_labels];

    % dimord is only set for non-source data
    is_source = is_ds_source_struct(ds);
    if ~is_source
        ft.dimord = cosmo_strjoin(dimord_labels, '_');
    end

    % store each feature attribute dimension value
    for k = 1:ndim
        field_name = dim_labels{k};
        field_value = ds.a.fdim.values{k};

        if strcmp(field_name, 'mom')
            % not stored in the output
            continue
        elseif strcmp(field_name, 'chan')
            field_name = 'label';
        elseif strcmp(field_name, 'pos')
            field_value = field_value';
        else
            % time or freq; fieldtrip requires a row vector
            field_value = field_value(:)';
        end

        ft.(field_name) = field_value;
    end

    sample_attribute_keys = {'rpt', 'trialinfo', 'cumtapcnt'};
    ft = ds_copy_fields_with_matching_sample_size(ft, ds, ...
                                                  sample_attribute_keys);

    % if fieldtrip is present, let it verify the output
    if cosmo_check_external('fieldtrip', false)
        if isequal(ft_datatype(ft), 'unknown')
            cosmo_warning('fieldtrip does not approve of this dataset');
        end
    end