cosmo meeg chan neighbors skl

function neighbors = cosmo_meeg_chan_neighbors(ds, varargin)
    % find neighbors of MEEG channels
    %
    % neighbors=cosmo_meeg_chan_neighbors(ds, ...)
    %
    % Inputs:
    %   ds                  MEEG dataset struct
    %   'label', lab        Labels to return in output, one of:
    %                       'layout'    : determine neighbors based on layout
    %                                     associated with ds (default). All
    %                                     labels in the layout are used as
    %                                     center labels.
    %                       'dataset'   : determine neighbors based on labels
    %                                     present in ds. Only labels present in
    %                                     ds are used as center labels
    %                       {x1,...,xn} : use labels x1 ... xn
    %   'chantype', tp      (optional) channel type of neighbors, can be one of
    %                       'eeg', 'meg_planar', 'meg_axial', or
    %                       'meg_combined_from_planar'.
    %                       Use 'all' to use all channel types associated with
    %                       lab, and 'all_combined' to use
    %                       'meg_combined_from_planar' with all other channel
    %                       types in ds except for 'meg_planar'.
    %                       If there is only one channel type associated with
    %                       lab, then this argument is not required.
    %   'radius', r         } select neighbors either within radius r, grow
    %   'count', c          } the radius to get neighbors at c locations,
    %   'delaunay', true    } or use Delaunay triangulation to find direct
    %                       } neighbors for each channel.
    %                       } These three options are mutually exclusive
    %
    %
    % Output:
    %   neighbors           Kx1 struct for K center labels, with fields:
    %     .label            center label
    %     .neighblabel      cell with labels of neighbors
    %
    % Examples:
    %     % (This example requires FieldTrip)
    %     cosmo_skip_test_if_no_external('fieldtrip');
    %     %
    %     % get neighbors within radius of .3 for EEG dataset
    %     ds=cosmo_synthetic_dataset('type','meeg',...
    %                                        'sens','eeg1010','size','big');
    %     % show all channel labels
    %     cosmo_disp(ds.a.fdim.values{1});
    %     %|| { 'TP10'  'TP7'  'TP8' ... 'A2'  'M1'  'M2'   }@1x94
    %     %
    %     % simulate the case where some channels are missing; here, every
    %     % 7th channel is removed
    %     ds=cosmo_slice(ds,mod(ds.fa.chan,7)~=2,2);
    %     ds=cosmo_dim_prune(ds);
    %     %
    %     % show remaining channel labels
    %     cosmo_disp(ds.a.fdim.values{1});
    %     %|| { 'TP10'  'TP8'  'TP9' ... 'A1'  'A2'  'M2'   }@1x80
    %     %
    %     % get neighbors for the channel layout associated with this
    %     % dataset. This layout ('EEG1010.lay') has 88 channel positions,
    %     % of which the last two are ignored because they are 'COMNT' and
    %     % 'SCALE'
    %     nbrs=cosmo_meeg_chan_neighbors(ds,'radius',.3);
    %     cosmo_disp(nbrs,'edgeitems',1);
    %     %|| <struct>@86x1
    %     %||    (1,1) .label
    %     %||            'Fp1'
    %     %||          .neighblabel
    %     %||            { 'Fp1'
    %     %||               :
    %     %||              'FC1' }@21x1
    %     %||      :           :
    %     %||    (86,1).label
    %     %||            'I2'
    %     %||          .neighblabel
    %     %||            { 'P4'
    %     %||               :
    %     %||              'I2' }@16x1
    %     %
    %     % since the dataset has only 80 channels, 74 of which are in the
    %     % layout, using the dataset's labels only (with the 'label'
    %     % argument) returns
    %     % only neighbors for channels in the dataset
    %     nbrs=cosmo_meeg_chan_neighbors(ds,'radius',.3,...
    %                                         'label','dataset');
    %     cosmo_disp(nbrs,'edgeitems',1);
    %     %|| <struct>@74x1
    %     %||    (1,1) .label
    %     %||            'Fp1'
    %     %||          .neighblabel
    %     %||            { 'Fp1'
    %     %||               :
    %     %||              'FC1' }@21x1
    %     %||      :           :
    %     %||    (74,1).label
    %     %||            'I2'
    %     %||          .neighblabel
    %     %||            { 'P4'
    %     %||               :
    %     %||              'I2' }@16x1
    %
    %     % (This example requires FieldTrip)
    %     cosmo_skip_test_if_no_external('fieldtrip');
    %     %
    %     % get neighbors at 4 neighboring sensor locations for
    %     % planar neuromag306 channels
    %     ds=cosmo_synthetic_dataset('type','meeg','size','big');
    %     nbrs=cosmo_meeg_chan_neighbors(ds,...
    %                     'chantype','meg_planar','count',4);
    %     cosmo_disp(nbrs,'edgeitems',1);
    %     %|| <struct>@204x1
    %     %||    (1,1)  .label
    %     %||             'MEG0113'
    %     %||           .neighblabel
    %     %||             { 'MEG0113'
    %     %||               'MEG0112'
    %     %||               'MEG0122'
    %     %||               'MEG0133' }
    %     %||      :            :
    %     %||    (204,1).label
    %     %||             'MEG2643'
    %     %||           .neighblabel
    %     %||             { 'MEG2423'
    %     %||               'MEG2422'
    %     %||               'MEG2642'
    %     %||               'MEG2643' }
    %
    %     % (This example requires FieldTrip)
    %     cosmo_skip_test_if_no_external('fieldtrip');
    %     %
    %     % get neighbors at 4 neighboring sensor locations for
    %     % planar neuromag306 channels, but with the center labels
    %     % the set of combined planar channels
    %     % (there are 8 channels in the .neighblabel fields, because
    %     %  there are two planar channels per combined channel)
    %     ds=cosmo_synthetic_dataset('type','meeg','size','big');
    %     nbrs=cosmo_meeg_chan_neighbors(ds,...
    %                     'chantype','meg_combined_from_planar','count',4);
    %     cosmo_disp(nbrs,'edgeitems',1);
    %     %|| <struct>@102x1
    %     %||    (1,1)  .label
    %     %||             'MEG0112+0113'
    %     %||           .neighblabel
    %     %||             { 'MEG0112'
    %     %||                  :
    %     %||               'MEG0343' }@8x1
    %     %||      :              :
    %     %||    (102,1).label
    %     %||             'MEG2642+2643'
    %     %||           .neighblabel
    %     %||             { 'MEG2422'
    %     %||                  :
    %     %||               'MEG2643' }@8x1
    %
    %     % (This example requires FieldTrip)
    %     cosmo_skip_test_if_no_external('fieldtrip');
    %     %
    %     % As above, but now use both the axial and planar channels.
    %     % Here the axial channels only have axial neighbors, and the planar
    %     % channels only have planar neighbors
    %     ds=cosmo_synthetic_dataset('type','meeg','size','big');
    %     nbrs=cosmo_meeg_chan_neighbors(ds,...
    %                            'chantype','all','count',4);
    %     cosmo_disp(nbrs,'edgeitems',1);
    %     %|| <struct>@306x1
    %     %||    (1,1)  .label
    %     %||             'MEG0111'
    %     %||           .neighblabel
    %     %||             { 'MEG0111'
    %     %||               'MEG0121'
    %     %||               'MEG0131'
    %     %||               'MEG0341' }
    %     %||      :            :
    %     %||    (306,1).label
    %     %||             'MEG2643'
    %     %||           .neighblabel
    %     %||             { 'MEG2423'
    %     %||               'MEG2422'
    %     %||               'MEG2642'
    %     %||               'MEG2643' }
    %
    %     % (This example requires FieldTrip)
    %     cosmo_skip_test_if_no_external('fieldtrip');
    %     %
    %     % As above, but now use both the axial and planar channels with
    %     % center labels for the planar channels from the combined_planar set.
    %     % Here the axial center channels have 4 axial neighbors each, while
    %     % the planar_combined channels have 8 planar (uncombined) neighbors
    %     % each.
    %     ds=cosmo_synthetic_dataset('type','meeg','size','big');
    %     nbrs=cosmo_meeg_chan_neighbors(ds,...
    %                            'chantype','all_combined','count',4);
    %     cosmo_disp(nbrs,'edgeitems',1);
    %     %|| <struct>@204x1
    %     %||    (1,1)  .label
    %     %||             'MEG0111'
    %     %||           .neighblabel
    %     %||             { 'MEG0111'
    %     %||               'MEG0121'
    %     %||               'MEG0131'
    %     %||               'MEG0341' }
    %     %||      :              :
    %     %||    (204,1).label
    %     %||             'MEG2642+2643'
    %     %||           .neighblabel
    %     %||             { 'MEG2422'
    %     %||                  :
    %     %||               'MEG2643' }@8x1
    %
    %
    % Notes:
    %  - this function returns a struct similar to FieldTrip's
    %    ft_prepare_neighbours, but not identical:
    %    * a center label can be a neighbor of itself
    %    * the neighbors are similar but not identical to FieldTrip's
    %      ft_prepare_neighbours
    %  - for searchlight and clustering purposes, use
    %    cosmo_meeg_chan_neighborhood
    %
    % See also: cosmo_meeg_chantype, ft_prepare_neighbours,
    %           cosmo_meeg_chan_neighborhood
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    defaults = struct();
    defaults.label = 'layout';
    opt = cosmo_structjoin(defaults, varargin);

    if ~isfield(opt, 'chantype')
        % no channel type given: a single neighbor computation
        neighbors = get_neighbors_with_chantype(ds, opt);
    else
        requested = opt.chantype;
        chantypes = get_chantypes(ds, requested);
        ntypes = numel(chantypes);

        % compute neighbors separately for each channel type, then stack
        % them vertically
        per_type = cell(ntypes, 1);
        for t = 1:ntypes
            type_opt = opt;
            type_opt.chantype = chantypes{t};
            per_type{t} = get_neighbors_with_chantype(ds, type_opt);
        end
        neighbors = vertcat(per_type{:});

        if strcmp(requested, 'all')
            % dataset channels absent from the result get an empty
            % neighbor list
            neighbors = add_missing_channels(ds, neighbors);
        end
    end

    % follow the dataset's channel order when possible
    neighbors = reorder_neighbors(ds, neighbors);

function full_neighbors = add_missing_channels(ds, neighbors)
    % Append an entry with an empty (1x0) neighbor list for every channel
    % label of ds that has no entry in neighbors.
    %
    % Inputs:
    %   ds              dataset struct (or cellstring with channel labels)
    %   neighbors       struct array with fields .label and .neighblabel
    % Output:
    %   full_neighbors  neighbors with the missing entries appended (in the
    %                   sorted order returned by setdiff)
    missing_labels = setdiff(get_dataset_channel_label(ds), ...
                             {neighbors.label});
    nmissing = numel(missing_labels);

    if nmissing > 0
        empty_nbrs = repmat({cell(1, 0)}, nmissing, 1);
        extra = struct('label', missing_labels(:), ...
                       'neighblabel', empty_nbrs);
        full_neighbors = cat(1, neighbors, extra);
    else
        full_neighbors = neighbors;
    end

function neighbors = reorder_neighbors(ds, neighbors)
    % Put the entries of neighbors in the same order as the channel labels
    % of ds.
    %
    % neighbors is returned unchanged unless it has as many entries as ds
    % has channel labels, and every row and column of the label overlap
    % matrix has a nonzero entry.
    ds_label = get_dataset_channel_label(ds);
    nbr_label = {neighbors.label};

    if numel(ds_label) == numel(nbr_label)
        % wrap every label in its own cell, so that cosmo_overlap compares
        % individual labels
        as_singletons = @(labels) cellfun(@(x){x}, labels, ...
                                          'UniformOutput', false);
        overlap = cosmo_overlap(as_singletons(ds_label), ...
                                as_singletons(nbr_label));

        all_cols_nonzero = all(sum(overlap, 1));
        if all_cols_nonzero && all(sum(overlap, 2))
            [row_idx, col_idx] = find(overlap);
            nmatch = numel(row_idx);
            % every column must contain exactly one nonzero entry
            assert(all(col_idx' == 1:nmatch));

            % invert the mapping
            order = row_idx;
            order(row_idx) = 1:nmatch;

            neighbors = neighbors(order);
            assert(isequal({neighbors.label}', ds_label(:)));
        end
    end

function chan_types = get_chantypes(ds, chantype)
    % Expand the 'chantype' option into a cell with channel types.
    %
    % Inputs:
    %   ds          MEEG dataset struct
    %   chantype    channel type string:
    %               'all'          : all channel types in ds except
    %                                'unknown'
    %               'all_combined' : as 'all', but with 'meg_planar'
    %                                replaced by 'meg_combined_from_planar'
    %               otherwise      : used as is
    % Output:
    %   chan_types  cell with channel type strings
    %
    % Raises an error for 'all_combined' if ds has no 'meg_planar' channels.
    switch chantype
        case 'all'
            chan_types = get_all_chantypes_from_ds(ds);
        case 'all_combined'
            % replace meg_planar by meg_combined_from_planar
            chan_types = get_all_chantypes_from_ds(ds);
            i = find(cosmo_match(chan_types, 'meg_planar'));
            if numel(i) ~= 1
                % (previously passed the undefined variables 'opt' and
                % 'chan_type', which masked this message)
                error(['dataset has no planar channels, therefore '...
                       '''%s'' is an invalid chantype'], chantype);
            end
            chan_types{i} = 'meg_combined_from_planar';
        otherwise
            chan_types = {chantype};
    end

function chan_types = get_all_chantypes_from_ds(ds)
    % Distinct channel types present in ds, without 'unknown'
    all_types = cosmo_meeg_chantype(ds);
    chan_types = setdiff(unique(all_types), {'unknown'});

function neighbors = get_neighbors_with_chantype(ds, opt)
    % Compute neighbors for a single channel type.
    %
    % Inputs:
    %   ds          MEEG dataset struct
    %   opt         options struct (.label, optionally .chantype, and one of
    %               .radius, .count or .delaunay)
    % Output:
    %   neighbors   struct array with one entry per layout label, with
    %               fields .label and .neighblabel
    [layout, label_keep] = get_layout(ds, opt);
    ds_label = get_dataset_channel_label(ds);

    % layout positions that correspond to channels in the dataset
    if ~isempty(label_keep)
        pos_msk = cosmo_overlap({ds_label}, layout.child_label) > 0;
    else
        pos_msk = cosmo_match(layout.label, ds_label);
    end

    nbr_msk = pairwise_neighbors(layout.pos, pos_msk, opt);

    center_label = layout.label;
    ncenter = numel(center_label);
    neighblabel = cell(ncenter, 1);

    if isfield(layout, 'child_label')
        % neighbors are the kept child labels of neighboring positions
        for c = 1:ncenter
            children = layout.child_label(nbr_msk(:, c));
            neighblabel{c} = intersect(label_keep, cat(1, children{:}));
        end
    else
        for c = 1:ncenter
            neighblabel{c} = center_label(nbr_msk(:, c));
        end
    end

    neighbors = struct('label', center_label, 'neighblabel', neighblabel);

function [lay, label] = get_layout(ds, opt)
    % Find the layout for ds and restrict it to the requested labels.
    %
    % Inputs:
    %   ds          MEEG dataset struct
    %   opt         options struct, passed to cosmo_meeg_find_layout;
    %               opt.label is 'layout', 'dataset', or a cellstring with
    %               labels
    % Outputs:
    %   lay         If the layout found for ds has a .parent field, this is
    %               the parent layout, keeping only entries whose
    %               .child_label overlaps with the requested labels.
    %               Otherwise, it is the layout itself, keeping only
    %               requested labels other than 'COMNT' and 'SCALE'.
    %   label       If the layout has a .parent field, this holds the layout
    %               labels that are requested and not 'COMNT' or 'SCALE'.
    %               Otherwise, it is [].
    ignore_label = {'COMNT', 'SCALE'};

    base_lay = cosmo_meeg_find_layout(ds, opt);

    if ischar(opt.label)
        switch opt.label
            case 'layout'
                use_label = base_lay.label;
            case 'dataset'
                use_label = get_dataset_channel_label(ds);
            otherwise
                % pass the value so that '%s' is filled in; without an
                % argument error() does not process format specifiers
                error('illegal label %s, use one of: layout, dataset', ...
                      opt.label);
        end
    elseif iscellstr(opt.label)
        use_label = opt.label;
    else
        error('illegal label, a string or cellstring is required');
    end

    in_label = cosmo_match(base_lay.label, use_label);
    not_in_ignore = ~cosmo_match(base_lay.label, ignore_label);
    keep_msk = in_label & not_in_ignore;

    if isfield(base_lay, 'parent')
        child_label = base_lay.parent.child_label;
        overlap = cosmo_overlap(child_label, {use_label});
        keep_parent_msk = overlap > 0;
        lay = slice_layout(base_lay.parent, keep_parent_msk);
        label = base_lay.label(keep_msk);
    else
        lay = slice_layout(base_lay, keep_msk);
        label = [];
    end

function lay = slice_layout(lay, to_keep)
    % Keep rows to_keep in every layout field with one row per label.
    %
    % Fields with 0 or 1 rows (and a different row count than .label) are
    % left as is; any other row count raises an error.
    nlabel = size(lay.label, 1);
    field_names = fieldnames(lay);

    for f = 1:numel(field_names)
        name = field_names{f};
        field_value = lay.(name);
        nrows = size(field_value, 1);

        if nrows == nlabel
            lay.(name) = field_value(to_keep, :);
        elseif nrows ~= 0 && nrows ~= 1
            error('layout field %s has %d values, must be 1 or %d', ...
                  name, nrows, nlabel);
        end
    end

function nbrs_msk = pairwise_neighbors(pos, msk, opt)
    % Dispatch to the neighbor function selected by exactly one of the
    % option fields .radius, .count or .delaunay.
    %
    % Inputs:
    %   pos         matrix with one row per position
    %   msk         mask with one element per row of pos
    %   opt         options struct
    % Output:
    %   nbrs_msk    logical matrix from the selected neighbor function
    assert(size(pos, 1) == numel(msk));

    metrics = {'radius'; 'count'; 'delaunay'};
    funcs = {@pairwise_euclidean_neighbors; ...
             @pairwise_nearest_neighbors; ...
             @pairwise_delaunay_neighbors};

    present = cosmo_isfield(opt, metrics);
    if sum(present) ~= 1
        error('Use one of these arguments to define neighbors: %s', ...
              cosmo_strjoin(metrics, ', '));
    end

    pos_idx = find(present);
    selected = funcs{pos_idx};
    nbrs_msk = selected(pos, msk, opt.(metrics{pos_idx}));

function nbrs_msk = pairwise_euclidean_neighbors(pos, msk, radius)
    % nbrs_msk(i,j) is true if positions i and j are at most radius apart
    % and position i is in msk
    within_radius = pairwise_euclidean_distance(pos) <= radius;
    nbrs_msk = bsxfun(@and, within_radius, msk(:));

function nbrs_msk = pairwise_nearest_neighbors(pos, msk, count)
    % Select the count nearest positions in msk (plus ties) for each center
    nbrs_msk = nearest_neighbors_from_distance( ...
                        pairwise_euclidean_distance(pos), msk, count);

function nbrs_msk = pairwise_delaunay_neighbors(pos, msk, steps)
    % Neighbors reachable by following up to `steps` Delaunay edges.
    %
    % Inputs:
    %   pos         Mx2 matrix with positions
    %   msk         M-element mask of positions in the dataset
    %   steps       number of steps (logical true is treated as 1)
    % Output:
    %   nbrs_msk    MxM logical matrix. Rows of positions outside msk are
    %               all false.
    raise_error_if_not_two_column_matrix(pos);

    if islogical(steps)
        steps = double(steps);
    end

    self_conn = double(diag(msk > 0));
    conn = self_conn;

    if steps > 0
        % self connections of masked positions plus direct neighbors
        one_step = self_conn + ...
                   double(pairwise_delaunay_direct_neighbors(pos, msk));
        for step = 1:steps
            % clamp to 0/1 so that values do not grow
            conn = double(conn * one_step > 0);
            conn(~msk, :) = 0;
        end
    end

    nbrs_msk = conn > 0;

function nbrs_msk = pairwise_delaunay_direct_neighbors(pos, msk)
    % Direct Delaunay neighbors for sensor positions that may contain
    % duplicates.
    %
    % Inputs:
    %   pos         Nx2 matrix with sensor positions
    %   msk         N-element mask of positions present in the dataset
    % Output:
    %   nbrs_msk    NxN logical matrix. nbrs_msk(a,b) is the neighbor status
    %               of the unique positions that rows a and b map to.
    %
    % The triangulation is computed on unique positions only, to avoid
    % duplicate sensor positions. Results are then mapped back to all rows.
    [idxs, unq_pos] = cosmo_index_unique(pos);
    nunq = numel(idxs);

    % msk is indexed by rows of pos, but the triangulation works on rows of
    % unq_pos. That can have fewer rows (duplicates removed) and possibly a
    % different order, so msk cannot be passed on directly. A unique
    % position is in the mask if any of the rows it represents is.
    unq_msk = false(nunq, 1);
    for k = 1:nunq
        unq_msk(k) = any(msk(idxs{k}));
    end

    unq_nbrs_msk = pairwise_delaunay_direct_neighbors_unique(unq_pos, ...
                                                             unq_msk);

    n = size(pos, 1);
    nbrs_msk = false(n);

    assert(isequal([1 1] * nunq, size(unq_nbrs_msk)));

    for k = 1:nunq
        for j = 1:nunq
            nbrs_msk(idxs{k}, idxs{j}) = unq_nbrs_msk(k, j);
        end
    end

function nbrs_msk = pairwise_delaunay_direct_neighbors_unique(pos, msk)
    % Direct (edge-sharing) Delaunay neighbors among unique positions.
    %
    % Inputs:
    %   pos         Nx2 matrix with unique positions
    %   msk         N-element mask of positions present in the dataset
    % Output:
    %   nbrs_msk    NxN symmetric logical matrix. nbrs_msk(i,j) is true if
    %               i==j, or if i and j share a triangle edge in a
    %               triangulation of all positions or of only the masked
    %               positions.
    n = size(pos, 1);
    nbrs_msk = logical(eye(n));

    if n == 0
        % nothing to triangulate; avoids growing the output matrix
        return
    end

    % act a bit like fieldtrip by also triangulating with the second
    % coordinate stretched (by .5 and by 2)
    stretch = [1 .5 2];

    % delaunay without applying the mask
    f_pos = delaunay_with_stretch(pos, stretch);

    % delaunay with applying the mask; map indices back to rows of pos
    idx = find(msk(:)');
    if isempty(idx)
        % no masked positions; indexing into empty idx would fail
        f_msk_pos = zeros(0, 3);
    else
        f_msk = delaunay_with_stretch(pos(idx, :), stretch);
        f_msk_pos = idx(f_msk);
    end

    % combine them
    f_all = [f_pos; f_msk_pos];
    assert(size(f_all, 2) == 3);

    % mark each triangle edge (column c -> next column, cyclically)
    for col = 1:3
        from = f_all(:, col);
        to = f_all(:, mod(col, 3) + 1);
        nbrs_msk(sub2ind([n n], from, to)) = true;
    end

    % make symmetric
    nbrs_msk = nbrs_msk | nbrs_msk';

function d = delaunay_with_stretch(pos, stretches)
    % Triangulate pos once per stretch factor and stack the triangles.
    %
    % Each factor in stretches scales the second coordinate. With fewer than
    % three positions, a single degenerate triangle [1 .. npos, 1, ...] is
    % returned instead.
    npos = size(pos, 1);

    if npos >= 3
        nstretch = numel(stretches);
        triangles = cell(nstretch, 1);
        for s = 1:nstretch
            scaled = bsxfun(@times, pos, [1 stretches(s)]);
            triangles{s} = delaunay(scaled(:, 1), scaled(:, 2));
        end
        d = vertcat(triangles{:});
        return
    end

    % minimal surface, cannot do Delaunay
    d = [1:npos, ones(1, 3 - npos)];

function nbrs_msk = nearest_neighbors_from_distance(d, msk, count)
    % Select the nearest positions for each center from a distance matrix.
    %
    % Inputs:
    %   d           n x n matrix with distances; column j holds the distances
    %               from center j
    %   msk         n-element mask of positions that may be selected
    %   count       number of neighbors to select per center
    % Output:
    %   nbrs_msk    n x n logical matrix. Column j marks the count nearest
    %               masked positions of center j, plus any positions tied
    %               with the count-th nearest one.
    n = size(d, 1);
    nbrs_msk = false(n);

    % positions outside the mask can never be selected
    d(~msk, :) = Inf;
    [sorted_d, order] = sort(d);

    min_rows = min(n, count);
    for center = 1:n
        col = sorted_d(:, center);
        nrows = min_rows;

        % include all positions tied with the count-th nearest one
        while nrows < n && col(nrows + 1) == col(min_rows)
            nrows = nrows + 1;
        end

        if nrows < count || isinf(col(nrows))
            error('Cannot select %d channels: only %d are present', ...
                  count, sum(msk));
        end

        nbrs_msk(order(1:nrows, center), center) = true;
    end

    assert(all(diag(nbrs_msk) == msk(:)));
    assert(all(sum(nbrs_msk, 1) >= count));

function d = pairwise_euclidean_distance(pos)
    % M x M matrix with Euclidean distances between the rows of the Mx2
    % matrix pos
    raise_error_if_not_two_column_matrix(pos);

    sq_dist = zeros(size(pos, 1));
    for dim = 1:2
        delta = bsxfun(@minus, pos(:, dim), pos(:, dim)');
        sq_dist = sq_dist + delta.^2;
    end

    d = sqrt(sq_dist);

function raise_error_if_not_two_column_matrix(pos)
    % Raise an error unless pos is a 2-D matrix with exactly two columns
    if ndims(pos) ~= 2 || size(pos, 2) ~= 2
        error('positions must be in Mx2 matrix');
    end

function chan_labels = get_dataset_channel_label(ds)
    % Channel labels of a dataset; a cellstring input is returned as is
    if ~iscellstr(ds)
        [unused, dim_index, unused, dim_name] = cosmo_dim_find(ds, ...
                                                               'chan', true);
        chan_labels = ds.a.(dim_name).values{dim_index};
        return
    end

    chan_labels = ds;