cosmo dim match skl

function msk=cosmo_dim_match(ds, varargin)
% return a mask indicating match of dataset dimensions with values
%
% msk=cosmo_match(ds, dim_label, dim_values1[...]
%
% Inputs:
%   ds                dataset struct or neighborhood struct
%   haystack*         numeric vector, or cell with strings, or string.
%                     A string is interpreted as the name of a feature
%                     dimension (e.g. 'i','j' or 'k' in fmri datasets;
%                     'chan', 'time', or 'freq' in MEEG datasets), and its
%                     respective values (from ds.a.fdim.values{dim}, where
%                     dim is the dimension corresponding to haystack) as
%                     indexed by ds.fa.(haystack) are used as haystack.
%   needle*           numeric vector, or cell with strings. A string is
%                     also allowed and interpreted as {needle}.
%                     A function handle is also allowed, in which case the
%                     value use for needle is the function applied to
%                     the corresponding value in ds.a.fdim.values.
%   dim               If the last argument, it sets the dimension along
%                     which dim_label has to be found. If omitted it
%                     finds the dimension in the dataset.
%
% Output:
%   msk               boolean array of the same size as haystack, with
%                     true where the value in haystack is equal to at least
%                     one value in needle. If multiple needle/haystack
%                     pairs are provided, then the haystack inputs should
%                     have the same number of elements, and msk contains
%                     the intersection of the individual masks.
%
% Examples:
%
%     % in an fMRI dataset, get all features with the first voxel dimension
%     % between 5 and 10, inclusive
%     ds=cosmo_synthetic_dataset('type','fmri','size','huge');
%     cosmo_disp(ds.a.fdim.values{1});
%     %|| [ 1         2         3  ...  18        19        20 ]@1x20
%     cosmo_disp(ds.fa.i)
%     %|| [ 1         2         3  ...  18        19        20 ]@1x6460
%     msk=cosmo_dim_match(ds,'i',5:10);
%     ds_sel=cosmo_slice(ds,msk,2);
%     % no pruning, so the fdim.values are not changed. A subset of
%     % features is selected
%     cosmo_disp(ds_sel.a.fdim.values{1});
%     %|| [ 1         2         3  ...  18        19        20 ]@1x20
%     cosmo_disp(ds_sel.fa.i)
%     %|| [ 5         6         7  ...  8         9        10 ]@1x1938
%
%     % For an MEEG dataset, get a selection of some channels
%     ds=cosmo_synthetic_dataset('type','meeg','size','huge');
%     cosmo_disp(ds.a.fdim.values{1},'edgeitems',2);
%     %|| { 'MEG0111'  'MEG0112'  ... 'MEG2642'  'MEG2643'   }@1x306
%     cosmo_disp(ds.fa.chan)
%     %|| [ 1         2         3  ...  304       305       306 ]@1x5202
%     %
%     % select channels
%     msk=cosmo_dim_match(ds,'chan',{'MEG1843','MEG2441'});
%     ds_sel=cosmo_slice(ds,msk,2);
%     %
%     % apply pruning, so that the .fa.chan goes from 1:nf, with nf the
%     % number of channels that were selected
%     ds_pruned=cosmo_dim_prune(ds_sel);
%     %
%     % show result
%     cosmo_disp(ds_pruned.a.fdim.values{1}); % 'chan' is first dimension
%     %|| { 'MEG1843'
%     %||   'MEG2441' }
%     cosmo_disp(ds_pruned.fa.chan)
%     %|| [ 1         2         1  ...  2         1         2 ]@1x34
%     %
%     % For the same MEEG dataset, get a selection of time points between 0
%     % and .3 seconds. A function handle is used to select the timepoints
%     selector=@(x) 0<=x & x<=.3; % use element-wise logical-and
%     msk=cosmo_dim_match(ds,'time',selector);
%     ds_sel=cosmo_slice(ds,msk,2);
%     ds_pruned=cosmo_dim_prune(ds_sel);
%     %
%     % show result
%     cosmo_disp(ds_pruned.a.fdim.values{2}); % 'time' is second dimension
%     %|| [ 0      0.05       0.1  ...  0.2      0.25       0.3 ]@1x7
%     cosmo_disp(ds_pruned.fa.time)
%     %|| [ 1         1         1  ...  7         7         7 ]@1x2142
%     %
%     % For the same MEEG dataset, compute a conjunction mask of the
%     % channels and time points selected above
%     msk=cosmo_dim_match(ds,'chan',{'MEG1843','MEG2441'},'time',selector);
%     ds_sel=cosmo_slice(ds,msk,2);
%     ds_pruned=cosmo_dim_prune(ds_sel);
%     %
%     % show result
%     cosmo_disp(ds_pruned.a.fdim.values); % 'chan' and 'time'
%     %|| { { 'MEG1843'  'MEG2441' }
%     %||   [ 0      0.05       0.1  ...  0.2      0.25       0.3 ]@1x7 }
%     cosmo_disp(ds_pruned.fa.chan)
%     %|| [ 1         2         1  ...  2         1         2 ]@1x14
%     cosmo_disp(ds_pruned.fa.time)
%     %|| [ 1         1         2  ...  6         7         7 ]@1x14
%
% Notes
%  - when haystack or needle are numeric vectors or cells of strings,
%    then this function behaves like cosmo_match (and does not consider
%    information in its first input argument ds).
%  - to remove dimension elements not included in the mask, use
%    cosmo_dim_prune. When the dataset is transformed back using
%    cosmo_map2{meeg,fmri,surface} it will not have these elements.
%    The only real use case is in MEEG datasets to remove time, channel, or
%    frequency elements; for fmri or surface datasets it is a bad idea to
%    use cosmo_dim_prune.
%
% See also: cosmo_match, cosmo_dim_prune
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    [dim_labels, dim_values, expected_dim]=process_input(ds, varargin{:});

    ndim=numel(dim_labels);
    for k=1:ndim
        dim_label=dim_labels{k};
        dim_value=dim_values{k};
        [haystack, needle, found_dim]=match_single_dim(ds, dim_label, ...
                                                           dim_value);

        if ~isempty(expected_dim) && expected_dim~=found_dim
            error(['label ''%s'' was expected in dimension %d, '...
                                'was found in %d'],...
                            dim_label, expected_dim, found_dim);
        end

        dim_msk=cosmo_match(haystack, needle);

        if k==1
            msk_length=numel(haystack);
            assert_mask_of_proper_length(msk_length, ds, found_dim);

            msk=dim_msk;
        else
            if ~isequal(size(msk),size(dim_msk))
                error('size mismatch for dimension label ''%s''',...
                        dim_label);
            end
            msk=msk & dim_msk;
            expected_dim=found_dim;
        end
    end


function assert_mask_of_proper_length(msk_length, ds, dim)
    % since the calling function ensures the size is properly set,
    % the size should be fine; this function should never throw an error
    if isfield(ds,'neighbors')
        assert(numel(ds.neighbors)==msk_length);
    else
        assert(size(ds.samples,dim)==msk_length);
    end


function [dim_labels, dim_values, dim]=process_input(ds, varargin)
% get dimension labels and values
% if the number of arguments in varargin is odd, then the last element is
% the dimension along which dim_labels and dim_values are to be found;
% otherwise it is set to empty
    if ~isstruct(ds)
        error('first argument must be a struct');
    end

    is_neighborhood=isfield(ds,'neighbors');
    if is_neighborhood
        cosmo_check_neighborhood(ds);
    else
        cosmo_check_dataset(ds);
    end

    narg=numel(varargin);
    ndim=floor(narg/2);
    dim_labels=varargin(1:2:(ndim*2));
    dim_values=varargin(2:2:(ndim*2));
    for k=1:ndim
        dim_label=dim_labels{k};

        if ~ischar(dim_label)
            error('argument %d must be a string', k*2);
        end

        dim_value=dim_values{k};

        if ~isvector(dim_value)
            error('argument %d must be a vector', k*2+1);
        end

        if ~(ischar(dim_value) || ...
                    iscellstr(dim_value) || ...
                    isnumeric(dim_value) || ...
                    isa(dim_value,'function_handle'))
            error(['argument %d must be a string, cell string, '...
                        'numeric vector, or function handle']);
        end
    end

    if mod(narg,2)==1
        dim=varargin{end};
    else
        dim=[];
    end


function [haystack, needle, found_dim]=match_single_dim(ds, haystack, ...
                                                          needle)

    % get value for needle and haystack
    [found_dim, index, attr_name, dim_name]=cosmo_dim_find(ds, ...
                                                haystack, true);

    vs=ds.a.(dim_name).values{index};
    if isa(needle,'function_handle')
        match_mask=needle(vs);
    else
        match_mask=cosmo_match(vs,needle);
    end

    % set new value based on indices of the matching mask
    needle=find(match_mask);
    haystack=ds.(attr_name).(ds.a.(dim_name).labels{index});