cosmo check neighborhood skl

function is_ok=cosmo_check_neighborhood(nbrhood,varargin)
% check that a neighborhood is kosher
%
% is_ok=cosmo_check_neighborhood(nbrhood,[raise],[ds],['show_warning',w])
%
% Inputs:
%     nbrhood               neighborhood struct, for example from
%                           cosmo_spherical_neighborhood,
%                           cosmo_surficial_neighborhood,
%                           surfing_interval_neighborhood, or
%                           cosmo_meeg_chan_neighborhood
%     raise                 (optional) logical; if true (the default), an
%                           error is thrown if nbrhood is not kosher
%     ds                    (optional) dataset struct. If provided, it is
%                           checked with cosmo_check_dataset; moreover the
%                           indices in nbrhood.neighbors must not exceed
%                           the number of samples (if nbrhood has .sa) or
%                           features (if nbrhood has .fa) in ds, and
%                           nbrhood.origin.a must be consistent with ds.a
%     'show_warning',w      If true (the default), then a warning is shown
%                           if nbrhood has no field .origin
%
% Output:
%     is_ok                 true if nbrhood is kosher, false otherwise
%
% Notes:
%   - optional arguments are recognized by their type (a logical sets
%     raise, a struct is taken as ds), so they can be given in any order
%   - when an error is thrown by one of the neighborhood-specific checks,
%     its message is prefixed by the name of that check
%     (e.g. 'check_neighbors: ...')
%
% Examples:
%     ds=cosmo_synthetic_dataset();
%     nbrhood=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
%     cosmo_check_neighborhood(nbrhood)
%     %|| true
%     %
%     cosmo_check_neighborhood(2)
%     %|| error('neighborhood is not a struct')
%     %
%     % error can be silenced
%     cosmo_check_neighborhood(2,false)
%     %|| false
%     %
%     fa=nbrhood.fa;
%     nbrhood=rmfield(nbrhood,'fa');
%     cosmo_check_neighborhood(nbrhood)
%     %|| error('field ''fa'' missing in neighborhood')
%     %
%     nbrhood.fa=fa;
%     nbrhood.neighbors{2}=-1;
%     cosmo_check_neighborhood(nbrhood)
%     %|| error('.neighbors{2} is not an integer row vector')
%     %
%     nbrhood.neighbors{2}=[1];
%     nbrhood.fa.chan=[3 2 1];
%     cosmo_check_neighborhood(nbrhood)
%     %|| error('fa.chan has 3 values in dimension 2, expected 6')
%
% See also: cosmo_spherical_neighborhood, surfing_interval_neighborhood
%           cosmo_surficial_neighborhood, cosmo_meeg_chan_neighborhood
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    [raise, ds, show_warning]=process_input(varargin{:});

    is_ok=false;

    % each checker returns an empty string if the neighborhood passes its
    % check, or an error message otherwise
    checkers={@check_basis,...
                @check_neighbors,...
                @check_origin_matches};


    for j=1:numel(checkers)
        checker=checkers{j};
        msg=checker(nbrhood, ds, show_warning);
        if ~isempty(msg)
            if raise
                error([func2str(checker) ': ' msg]);
            end
            return
        end
    end


    % treat like dataset: replace .neighbors by an empty .samples matrix
    % with one feature (if .fa is present) or one sample (otherwise) per
    % neighborhood, so that cosmo_check_dataset can verify the
    % attributes against the number of neighborhoods
    nfeatures=numel(nbrhood.neighbors);
    nbrhood=rmfield(nbrhood,'neighbors');

    if isfield(nbrhood,'fa')
        nbrhood.samples=zeros(0,nfeatures);
    else
        nbrhood.samples=zeros(nfeatures,0);
    end

    % .origin is not a dataset field; it was checked above
    if isfield(nbrhood,'origin')
        nbrhood=rmfield(nbrhood,'origin');
    end

    is_ok=cosmo_check_dataset(nbrhood,raise);


function [infix, absent_infix]=get_attr_infix(nbrhood)
% return 's' if nbrhood has a field .sa, or 'f' if it has a field .fa;
% the second output is the other letter. An error is thrown unless
% exactly one of .sa and .fa is present.
    present=cosmo_match({'sa','fa'},fieldnames(nbrhood));
    if sum(present)~=1
        error('exactly one of .sa or .fa must be present');
    end

    % position 1 corresponds to 'sa', position 2 to 'fa'
    letters='sf';
    pos=find(present);
    infix=letters(pos);
    absent_infix=letters(3-pos);

function msg=check_basis(nbrhood, ds, show_warning)
% verify that nbrhood is a struct that has a field .neighbors and no
% fields other than .neighbors, .a, .fa, .sa and .origin.
% Returns an empty string if ok, an error message otherwise.
% The inputs ds and show_warning are not used here; they are present so
% that all checkers share the same signature.
    msg='';
    if ~isstruct(nbrhood)
        msg='neighborhood is not a struct';
        return
    end

    keys={'neighbors','a','fa','sa','origin'};
    delta=setdiff(fieldnames(nbrhood),keys);
    if ~isempty(delta)
        first=delta{1};
        msg=sprintf('field ''%s'' not allowed in neighborhood',first);
        return
    end

    if ~isfield(nbrhood,'neighbors')
        % report (rather than throw) so that the caller's 'raise'
        % option is honored
        msg='missing field .neighbors';
        return
    end


function tf=is_positive_int_row_vector(x)
% true if x is empty, or a numeric 1xN row vector whose elements are all
% finite integers >= 1. Non-numeric input (e.g. char or cell) is rejected
% before any arithmetic is attempted on it.
    tf=isempty(x) || ...
            (isnumeric(x) && ndims(x)==2 && size(x,1)==1 && ...
                all(isfinite(x)) && min(x)>=1 && all(round(x)==x));


function msg=check_neighbors(nbrhood, ds, show_warning)
% verify that nbrhood.neighbors is a cell with one column where every
% element is empty or a row vector of positive integers. If a dataset ds
% is given (non-empty), the indices must also not exceed the number of
% samples (if nbrhood has .sa) or features (if nbrhood has .fa) in ds.
% Returns an empty string if ok, an error message otherwise.
% The input show_warning is not used here.
    msg='';

    nbrs=nbrhood.neighbors;

    if ~iscell(nbrs)
        msg='.neighbors is not a cell';
        return;
    end

    if size(nbrs,2)~=1
        msg='.neighbors is not of size Kx1';
        return;
    end

    if isempty(ds)
        % no dataset: only the type of the indices can be checked
        max_feature=Inf;
    else
        % .sa neighborhoods index samples (dim 1),
        % .fa neighborhoods index features (dim 2)
        if strcmp(get_attr_infix(nbrhood),'s')
            dim=1;
        else
            dim=2;
        end
        max_feature=size(ds.samples,dim);
    end

    for j=1:numel(nbrs)
        nbr=nbrs{j};

        % validate the contents first, so that the range comparison
        % below is never applied to non-numeric values
        if ~is_positive_int_row_vector(nbr)
            msg=sprintf('.neighbors{%d} is not an integer row vector',j);
            return;
        end

        if ~isempty(nbr) && any(nbr>max_feature)
            msg=sprintf(['.neighbors{%d} exceeds the number '...
                            'of features (%d) in the dataset'],...
                                j, max_feature);
            return
        end
    end


function msg=check_origin_matches(nbrhood, ds, show_warning)
% verify that the neighborhood's .origin is consistent with dataset ds.
% If nbrhood has no .origin (legacy neighborhood), a warning is shown
% when show_warning is true and no further checks are done. If ds is
% empty, nothing is checked. Otherwise, when both origin and ds have a
% field .a, the dimension struct (.a.sdim or .a.fdim) and the matching
% attributes are compared, and the remaining .a fields must be equal.
% Returns an empty string if ok, an error message otherwise.
    msg='';

    if ~isfield(nbrhood,'origin')
        % legacy neighborhood: do not report an error, and skip the
        % remaining checks because there is no origin to compare with
        if show_warning
            cosmo_warning(['Legacy warning: newer versions of CoSMoMVPA '...
                            'require a field .origin in the neighborhood '...
                            'struct']);
        end
        return;
    end

    if isempty(ds)
        % no dataset, so no further checks
        return;
    end

    origin=nbrhood.origin;

    infix=get_attr_infix(nbrhood);
    dim_name=[infix 'dim'];

    if ~(isfield(origin,'a') && isfield(ds,'a'))
        return;
    end

    origin_a=origin.a;
    ds_a=ds.a;

    if isfield(ds_a,dim_name)
        if isfield(origin_a,dim_name)
            msg=check_xdim_matches(origin_a.(dim_name), ...
                                    ds_a.(dim_name), ...
                                    dim_name);
            if ~isempty(msg)
                return;
            end

            msg=check_xa_matches(origin,ds,infix);
            if ~isempty(msg)
                return;
            end

            origin_a=rmfield(origin_a,dim_name);
        end
        % the dimension struct has been compared (if possible); compare
        % only the remaining .a fields below
        ds_a=rmfield(ds_a,dim_name);
    end

    if ~cosmo_isequaln(origin_a, ds_a)
        % report (rather than throw) so that 'raise' is honored
        msg='.a mismatch between dataset and neighborhood';
    end

function msg=check_xa_matches(origin,ds,infix)
% verify that, for each label in origin.a.([infix 'dim']).labels, the
% attribute with that name in .([infix 'a']) exists in both origin and ds
% and has equal values in both (as determined by cosmo_isequaln).
% Returns an empty string if ok, an error message otherwise.
    msg='';
    attr_name=[infix 'a'];
    dim_name=[infix 'dim'];
    labels=origin.a.(dim_name).labels;

    for k=1:numel(labels)
        label=labels{k};
        attr_path=[attr_name '.' label];

        % check presence in both structs before reading the values, so
        % that a missing attribute is reported instead of crashing
        if ~cosmo_isfield(origin,attr_path) || ...
                ~cosmo_isfield(ds,attr_path) || ...
                ~cosmo_isequaln(ds.(attr_name).(label),...
                        origin.(attr_name).(label))
            msg=sprintf(['.%sa.%s mismatch between dataset and '...
                    'neighborhood'],infix,label);
            return
        end
    end


function msg=check_xdim_matches(origin_xdim, ds_xdim, dim_name)
% verify that the dimension struct origin_xdim (from the neighborhood's
% origin) matches ds_xdim (from the dataset). Both must have the same
% field names; each field must be a cell in both, with the same number of
% elements; and corresponding elements must be equal (cosmo_isequaln),
% where the dataset element may also be the transpose of the origin one.
% dim_name (e.g. 'fdim') is used only in the error messages.
% Returns an empty string if ok, an error message otherwise.
    msg='';
    keys=fieldnames(ds_xdim);
    if ~isequal(sort(keys),sort(fieldnames(origin_xdim)))
        msg=sprintf(['.a.%s key mismatch between dataset '...
                    'and neighborhood'], dim_name);
        return;
    end

    for k=1:numel(keys)
        key=keys{k};

        origin_v=origin_xdim.(key);
        ds_v=ds_xdim.(key);

        if ~(iscell(origin_v) && iscell(ds_v))
            msg=sprintf('.a.%s ''%s'' must be a cell',...
                                    dim_name, key);
            return
        end

        if numel(origin_v)~=numel(ds_v)
            % the format has a placeholder for each argument; with only
            % one placeholder, sprintf would recycle the format
            msg=sprintf(['.a.%s ''%s'' size mismatch between ',...
                            'dataset and neighborhood'],...
                                    dim_name, key);
            return
        end

        for j=1:numel(origin_v)
            if ~(cosmo_isequaln(origin_v{j},ds_v{j}) || ...
                            cosmo_isequaln(origin_v{j},ds_v{j}'))
                msg=sprintf(['.a.%s ''%s'' value mismatch '...
                            'between dataset and neighborhood'], ...
                                    dim_name, key);
                return
            end
        end
    end


function [raise, ds, show_warning]=process_input(varargin)
% parse the optional arguments of cosmo_check_neighborhood.
% Arguments are recognized by their type: a logical sets raise, a struct
% is taken as the dataset ds (and checked with cosmo_check_dataset), and
% the string 'show_warning' must be followed by its value. Any other
% string raises an error; arguments of other types are ignored.
% Defaults: raise=true, ds=[], show_warning=true.
    raise=true;
    ds=[];
    show_warning=true;

    narg=numel(varargin);
    k=0;
    while k<narg
        k=k+1;
        arg=varargin{k};

        if islogical(arg)
            raise=arg;
        elseif isstruct(arg)
            ds=arg;
            cosmo_check_dataset(ds);
        elseif ischar(arg)
            switch arg
                case 'show_warning'
                    % only a known key needs a following value; checking
                    % this first would misreport unknown trailing strings
                    if k==narg
                        error('missing argument after ''show_warning''');
                    end
                    k=k+1;
                    show_warning=varargin{k};
                otherwise
                    error('illegal argument at position %d', k);
            end
        end
    end