function is_ok = cosmo_check_neighborhood(nbrhood, varargin)
% check that a neighborhood is kosher
%
% is_ok=cosmo_check_neighborhood(nbrhood,[raise],[ds],['show_warning',w])
%
% Inputs:
%   nbrhood           neighborhood struct, for example from
%                     cosmo_spherical_neighborhood,
%                     cosmo_surficial_neighborhood,
%                     surfing_interval_neighborhood, or
%                     cosmo_meeg_chan_neighborhood
%   raise             (optional) if set to true (the default), an
%                     error is thrown if nbrhood is not kosher
%   ds                (optional) dataset struct the neighborhood is
%                     intended for; it is itself checked with
%                     cosmo_check_dataset. If provided, indices in
%                     nbrhood.neighbors must not exceed the number of
%                     samples (for .sa) or features (for .fa) in ds, and
%                     nbrhood.origin (if present) is compared against ds
%   'show_warning',w  If true (the default), then a warning is shown
%                     if nbrhood has no origin
%
% Output:
%   is_ok             true if nbrhood is kosher, false otherwise
%
%
% Examples:
% ds=cosmo_synthetic_dataset();
% nbrhood=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
% cosmo_check_neighborhood(nbrhood)
% %|| true
% %
% cosmo_check_neighborhood(2)
% %|| error('neighborhood is not a struct')
% %
% % error can be silenced
% cosmo_check_neighborhood(2,false)
% %|| false
% %
% fa=nbrhood.fa;
% nbrhood=rmfield(nbrhood,'fa');
% cosmo_check_neighborhood(nbrhood)
% %|| error('field ''fa'' missing in neighborhood')
% %
% nbrhood.fa=fa;
% nbrhood.neighbors{2}=-1;
% cosmo_check_neighborhood(nbrhood)
% %|| error('.neighbors{2} is not a row vector with integers')
% %
% nbrhood.neighbors{2}=[1];
% nbrhood.fa.chan=[3 2 1];
% cosmo_check_neighborhood(nbrhood)
% %|| error('fa.chan has 3 values in dimension 2, expected 6')
%
% See also: cosmo_spherical_neighborhood, surfing_interval_neighborhood
% cosmo_surficial_neighborhood, cosmo_meeg_chan_neighborhood
%
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
[raise, ds, show_warning] = process_input(varargin{:});

% each checker returns '' on success or a description of the problem
check_funcs = {@check_basis, @check_neighbors, @check_origin_matches};

is_ok = false;
for idx = 1:numel(check_funcs)
    check_func = check_funcs{idx};
    err_msg = check_func(nbrhood, ds, show_warning);
    if isempty(err_msg)
        continue;
    end

    if raise
        error([func2str(check_func) ': ' err_msg]);
    end
    return;
end

% Without .neighbors (and .origin) a neighborhood has the layout of a
% dataset; give it an empty .samples matrix of matching size so that
% cosmo_check_dataset can validate the attributes.
nbr_count = numel(nbrhood.neighbors);
ds_like = rmfield(nbrhood, 'neighbors');
if isfield(ds_like, 'origin')
    ds_like = rmfield(ds_like, 'origin');
end

if isfield(ds_like, 'fa')
    ds_like.samples = zeros(0, nbr_count);
else
    ds_like.samples = zeros(nbr_count, 0);
end

is_ok = cosmo_check_dataset(ds_like, raise);
function [infix, absent_infix] = get_attr_infix(nbrhood)
% Determine whether the neighborhood indexes samples or features
%
% [infix, absent_infix] = get_attr_infix(nbrhood)
%
% Input:
%   nbrhood         neighborhood struct
% Outputs:
%   infix           's' if nbrhood has an .sa field, 'f' if it has .fa
%   absent_infix    the other one of 's' and 'f'
%
% Raises an error unless exactly one of .sa and .fa is present.
names = fieldnames(nbrhood);
has_sa = any(strcmp(names, 'sa'));
has_fa = any(strcmp(names, 'fa'));

if has_sa == has_fa
    error('exactly one of .sa or .fa must be present');
end

if has_fa
    infix = 'f';
    absent_infix = 's';
else
    infix = 's';
    absent_infix = 'f';
end
function msg = check_basis(nbrhood, ds, show_warning)
% Verify the basic layout of a neighborhood struct
%
% msg = check_basis(nbrhood, ds, show_warning)
%
% Inputs:
%   nbrhood         candidate neighborhood
%   ds              unused; present so all checkers share one signature
%   show_warning    unused; present so all checkers share one signature
% Output:
%   msg             '' if nbrhood is a single struct that has a
%                   .neighbors field and no fields other than
%                   .neighbors, .a, .fa, .sa and .origin; otherwise a
%                   description of the first problem found
%
% All problems are reported through msg (never thrown), so that the
% caller can honour its 'raise' option.
msg = '';
if ~isstruct(nbrhood)
    msg = 'neighborhood is not a struct';
    return
end

if numel(nbrhood) ~= 1
    % a struct array would make later field access ambiguous
    msg = 'neighborhood is not a single struct';
    return
end

allowed_keys = {'neighbors', 'a', 'fa', 'sa', 'origin'};
delta = setdiff(fieldnames(nbrhood), allowed_keys);
if ~isempty(delta)
    msg = sprintf('field ''%s'' not allowed in neighborhood', delta{1});
    return
end

if ~isfield(nbrhood, 'neighbors')
    msg = 'missing field .neighbors';
    return
end
function tf = is_positive_int_row_vector(x)
% Return true if x is empty or a numeric 1xN row of finite integers >= 1
%
% tf = is_positive_int_row_vector(x)
%
% Any empty value is accepted. Non-empty values must be numeric (char
% and logical are rejected), two-dimensional with a single row, finite
% (Inf is not a valid index), at least 1, and integer-valued.
if isempty(x)
    tf = true;
    return
end

tf = isnumeric(x) && ...
     ndims(x) == 2 && size(x, 1) == 1 && ...
     all(isfinite(x)) && ...
     min(x) >= 1 && ...
     all(round(x) == x);
function msg = check_neighbors(nbrhood, ds, show_warning)
% Verify that .neighbors is a Kx1 cell of valid index row vectors
%
% msg = check_neighbors(nbrhood, ds, show_warning)
%
% Inputs:
%   nbrhood         neighborhood struct with a .neighbors field
%   ds              dataset struct, or [] if none was provided. When
%                   non-empty, every index in .neighbors must not exceed
%                   the number of samples (if nbrhood has .sa) or
%                   features (if nbrhood has .fa) in ds
%   show_warning    unused; present so all checkers share one signature
% Output:
%   msg             '' if all neighbors are valid, otherwise a
%                   description of the first problem found
msg = '';
nbrs = nbrhood.neighbors;
if ~iscell(nbrs)
    msg = '.neighbors is not a cell';
    return
end

if size(nbrs, 2) ~= 1
    msg = '.neighbors is not of size Kx1';
    return
end

if isempty(ds)
    max_index = Inf;
    unit_name = 'features';
else
    switch get_attr_infix(nbrhood)
        case 's'
            dim = 1;
            unit_name = 'samples';
        case 'f'
            dim = 2;
            unit_name = 'features';
    end
    max_index = size(ds.samples, dim);
end

for j = 1:numel(nbrs)
    nbr = nbrs{j};

    % validate the content first, so that the range comparison below
    % never operates on non-numeric or malformed values
    if ~is_positive_int_row_vector(nbr)
        msg = sprintf('.neighbors{%d} is not an integer row vector', j);
        return
    end

    if any(nbr > max_index)
        msg = sprintf(['.neighbors{%d} exceeds the number '...
                       'of %s (%d) in the dataset'], ...
                      j, unit_name, max_index);
        return
    end
end
function msg = check_origin_matches(nbrhood, ds, show_warning)
% Verify that nbrhood.origin is consistent with the dataset ds
%
% msg = check_origin_matches(nbrhood, ds, show_warning)
%
% Inputs:
%   nbrhood         neighborhood struct
%   ds              dataset struct, or [] to skip the comparison
%   show_warning    if true, warn when nbrhood has no .origin field
% Output:
%   msg             '' if consistent (or there is nothing to compare),
%                   otherwise a description of the first mismatch
%
% A neighborhood without .origin is a legacy neighborhood: it passes
% this check (optionally with a warning) regardless of show_warning,
% because there is no origin to compare against.
msg = '';
if ~isfield(nbrhood, 'origin')
    if show_warning
        cosmo_warning(['Legacy warning: newer versions of CoSMoMVPA '...
                       'require a field .origin in the neighborhood '...
                       'struct']);
    end
    return
end

if isempty(ds)
    % no dataset, so no further checks
    return
end

origin = nbrhood.origin;
infix = get_attr_infix(nbrhood);

if ~isfield(origin, 'a') || ~isfield(ds, 'a')
    % no dataset-level attributes to compare
    return
end

dim_name = [infix 'dim'];
origin_a = origin.a;
ds_a = ds.a;

if isfield(ds_a, dim_name)
    if isfield(origin_a, dim_name)
        msg = check_xdim_matches(origin_a.(dim_name), ...
                                 ds_a.(dim_name), ...
                                 dim_name);
        if ~isempty(msg)
            return
        end

        msg = check_xa_matches(origin, ds, infix);
        if ~isempty(msg)
            return
        end

        origin_a = rmfield(origin_a, dim_name);
    end
    ds_a = rmfield(ds_a, dim_name);
end

% all remaining .a fields must be identical; report rather than throw
% so that the caller's 'raise' option is honoured
if ~cosmo_isequaln(origin_a, ds_a)
    msg = '.a mismatch between dataset and neighborhood';
end
function msg = check_xa_matches(origin, ds, infix)
% Verify that dimension attributes agree between origin and dataset
%
% msg = check_xa_matches(origin, ds, infix)
%
% Inputs:
%   origin      .origin field of a neighborhood; must contain
%               origin.a.<infix>dim.labels, a cell of attribute names
%   ds          dataset struct
%   infix       's' or 'f', selecting .sa/.sdim or .fa/.fdim
% Output:
%   msg         '' if, for every label, the attribute
%               <infix>a.<label> exists in both origin and ds and the
%               two values are equal according to cosmo_isequaln;
%               otherwise a description of the first mismatch
msg = '';
attr_name = [infix 'a'];
dim_name = [infix 'dim'];
labels = origin.a.(dim_name).labels;

for k = 1:numel(labels)
    label = labels{k};
    field_path = [attr_name '.' label];

    % check presence in both structs before reading, so that an
    % attribute missing from the dataset is reported instead of
    % causing a non-existent-field error
    if ~cosmo_isfield(origin, field_path) || ...
            ~cosmo_isfield(ds, field_path) || ...
            ~cosmo_isequaln(ds.(attr_name).(label), ...
                            origin.(attr_name).(label))
        msg = sprintf(['.%sa.%s mismatch between dataset and '...
                       'neighborhood'], infix, label);
        return
    end
end
function msg = check_xdim_matches(origin_xdim, ds_xdim, dim_name)
% Verify that two .a.<dim_name> structs (e.g. .a.fdim) are equivalent
%
% msg = check_xdim_matches(origin_xdim, ds_xdim, dim_name)
%
% Inputs:
%   origin_xdim     dimension struct from the neighborhood's origin
%   ds_xdim         dimension struct from the dataset
%   dim_name        name of the dimension field, used in messages only
% Output:
%   msg             '' if both structs have the same field names, every
%                   field is a cell in both with the same number of
%                   elements, and corresponding elements are equal
%                   either as-is or after transposing the dataset
%                   element; otherwise a description of the first
%                   mismatch
msg = '';
keys = fieldnames(ds_xdim);
if ~isequal(sort(keys), sort(fieldnames(origin_xdim)))
    msg = sprintf(['.a.%s key mismatch between dataset '...
                   'and neighborhood'], dim_name);
    return
end

for k = 1:numel(keys)
    key = keys{k};
    origin_v = origin_xdim.(key);
    ds_v = ds_xdim.(key);

    if ~(iscell(origin_v) && iscell(ds_v))
        msg = sprintf('.a.%s ''%s'' must be a cell', ...
                      dim_name, key);
        return
    end

    if numel(origin_v) ~= numel(ds_v)
        % every argument needs its own conversion: sprintf recycles the
        % format for surplus arguments, which garbled this message
        msg = sprintf(['.a.%s ''%s'' size mismatch between ', ...
                       'dataset and neighborhood'], ...
                      dim_name, key);
        return
    end

    for j = 1:numel(origin_v)
        % values may be stored as row or column vectors
        if ~(cosmo_isequaln(origin_v{j}, ds_v{j}) || ...
                cosmo_isequaln(origin_v{j}, ds_v{j}'))
            msg = sprintf(['.a.%s ''%s'' value mismatch '...
                           'between dataset and neighborhood'], ...
                          dim_name, key);
            return
        end
    end
end
function [raise, ds, show_warning] = process_input(varargin)
% Parse the optional arguments of cosmo_check_neighborhood
%
% [raise, ds, show_warning] = process_input(...)
%
% Arguments may appear in any order:
%   logical               sets raise
%   struct                sets ds; it is validated with cosmo_check_dataset
%   'show_warning', w     sets show_warning to w
% Arguments of any other type are ignored; any other string is an error.
% Defaults: raise=true, ds=[], show_warning=true.
raise = true;
ds = [];
show_warning = true;

narg = numel(varargin);
k = 0;
while k < narg
    k = k + 1;
    arg = varargin{k};
    if islogical(arg)
        raise = arg;
    elseif isstruct(arg)
        ds = arg;
        cosmo_check_dataset(ds);
    elseif ischar(arg)
        switch arg
            case 'show_warning'
                % the key must be followed by its value; check this
                % only for a recognized key so that unknown strings
                % get an accurate error message
                if k == narg
                    error('missing argument after ''show_warning''');
                end
                k = k + 1;
                show_warning = varargin{k};
            otherwise
                error('illegal argument at position %d', k);
        end
    end
end