cosmo spherical neighborhood

function nbrhood=cosmo_spherical_neighborhood(ds, varargin)
% computes neighbors for a spherical searchlight
%
% nbrhood=cosmo_spherical_neighborhood(ds, opt)
%
% Inputs
%   ds                  a dataset struct, either:
%                       - in fmri form (from cosmo_fmri_dataset), when
%                         ds.fa has the fields .i, .j and .k
%                       - in meeg source form (from cosmo_meeg_dataset),
%                         when ds.fa has the field .pos. In this case, the
%                         features must have positions that can be
%                         converted to a grid.
%   'radius', r         } either use a radius of r, or select
%   'count', c          } approximately c voxels per searchlight
%                       Notes:
%                       - These two options are mutually exclusive
%                       - When using this option for an fmri dataset, the
%                         radius r is expressed in voxel units; for an meeg
%                         source dataset, the radius r is in whatever units
%                         the source dataset uses for the positions
%   'progress', p       show progress every p features (default: 1000)
%
% Outputs
%   nbrhood             dataset-like struct without .sa or .samples, with:
%     .a                dataset attributes, from dataset.a
%     .fa               feature attributes with the same fields as fs.fa,
%                       and in addition the fields:
%       .nvoxels        1xP number of voxels in each searchlight
%       .radius         1xP radius in voxel units
%       .center_ids     1xP feature center id
%     .neighbors        Px1 cell so that center2neighbors{k}==nbrs contains
%                       the feature ids of the neighbors of feature k
%                       If the dataset has a field ds.fa.inside, then
%                       features that are not inside are not included as
%                       neighbors in the output
%     .origin           Has fields .a and .fa from input dataset
%
%
% Example:
%     ds=cosmo_synthetic_dataset('type','fmri');
%     radius=1; % radius=3 is typical for 'real-world' searchlights
%     nbrhood=cosmo_spherical_neighborhood(ds,'radius',radius,...
%                                             'progress',false);
%     cosmo_disp(nbrhood)
%     %|| .a
%     %||   .fdim
%     %||     .labels
%     %||       { 'i'  'j'  'k' }
%     %||     .values
%     %||       { [ 1         2         3 ]  [ 1         2 ]  [ 1 ] }
%     %||   .vol
%     %||     .mat
%     %||       [ 2         0         0        -3
%     %||         0         2         0        -3
%     %||         0         0         2        -3
%     %||         0         0         0         1 ]
%     %||     .dim
%     %||       [ 3         2         1 ]
%     %||     .xform
%     %||       'scanner_anat'
%     %|| .fa
%     %||   .nvoxels
%     %||     [ 3         4         3         3         4         3 ]
%     %||   .radius
%     %||     [ 1         1         1         1         1         1 ]
%     %||   .center_ids
%     %||     [ 1         2         3         4         5         6 ]
%     %||   .i
%     %||     [ 1         2         3         1         2         3 ]
%     %||   .j
%     %||     [ 1         1         1         2         2         2 ]
%     %||   .k
%     %||     [ 1         1         1         1         1         1 ]
%     %|| .neighbors
%     %||   { [ 1         4         2 ]
%     %||     [ 2         1         5         3 ]
%     %||     [ 3         2         6 ]
%     %||     [ 4         1         5 ]
%     %||     [ 5         4         2         6 ]
%     %||     [ 6         5         3 ]           }
%     %|| .origin
%     %||   .a
%     %||     .fdim
%     %||       .labels
%     %||         { 'i'
%     %||           'j'
%     %||           'k' }
%     %||       .values
%     %||         { [ 1         2         3 ]
%     %||           [ 1         2 ]
%     %||           [ 1 ]                     }
%     %||     .vol
%     %||       .mat
%     %||         [ 2         0         0        -3
%     %||           0         2         0        -3
%     %||           0         0         2        -3
%     %||           0         0         0         1 ]
%     %||       .dim
%     %||         [ 3         2         1 ]
%     %||       .xform
%     %||         'scanner_anat'
%     %||   .fa
%     %||     .i
%     %||       [ 1         2         3         1         2         3 ]
%     %||     .j
%     %||       [ 1         1         1         2         2         2 ]
%     %||     .k
%     %||       [ 1         1         1         1         1         1 ]
%
%
% Notes:
%   - this function can return neighborhoods with either a fixed number of
%     features, or a fixed radius. When used with a searchlight, the
%     former has the advantage that the number of features is less
%     variable (especially near edges of the brain, in an fmri dataset),
%     which can make it easier to compare result in different regions as
%     the number of features can affect
%     pattern discriminablity. The latter has the advantage that the
%     smoothness of the output maps under the null hypothesis can be more
%     uniformly smooth.
%
% See also: cosmo_fmri_dataset, cosmo_meeg_dataset, cosmo_searchlight
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    check_input(varargin{:});

    defaults=struct();
    defaults.progress=1000;
    opt=cosmo_structjoin(defaults,varargin);

    [use_fixed_radius,radius,voxel_count]=get_selection_params(opt);
    cosmo_check_dataset(ds);

    % ensure not too many features are requested
    feature_mask=get_features_mask(ds);
    nfeatures=sum(feature_mask);
    if nfeatures<voxel_count
        error('Cannot select %d features: only %d are present',...
                    voxel_count, nfeatures);
    end

    % get attributes for output dataset, and the positions and dimension of
    % the grid
    [fdim,fa,pos,grid_dim]=get_spherical_attributes(ds,feature_mask);

    % compute voxel offsets relative to origin
    [sphere_offsets, distances]=get_sphere_offsets(radius);

    % get mapping from linear ids to feature ids
    [lin2feature_ids,lin2feature_mask]=get_lin2feature_ids(grid_dim,pos...
                                        ,feature_mask);

    % number of features associated with each linear id
    feature_id_count=sum(lin2feature_mask,2);

    show_progress=opt.progress>0;

    if show_progress
        clock_start=clock();
        prev_progress_msg='';
    end

    % a position may occur at multiple features; only consider unique
    % positions
    pos(:,~feature_mask)=Inf;
    [center_idxs,unq_pos]=cosmo_index_unique(pos');
    keep_unq_pos=~any(isinf(unq_pos),2);
    center_idxs=center_idxs(keep_unq_pos);
    unq_pos=unq_pos(keep_unq_pos,:);
    nunq_centers=numel(center_idxs);

    % allocate space for output
    ncenters=nunq_centers;
    neighbors=cell(ncenters,1);
    nvoxels=zeros(1,ncenters);
    final_radius=zeros(1,ncenters);
    visited=false(1,ncenters);
    center_ids=zeros(1,ncenters);

    % go over all features
    for k=1:nunq_centers
        variable_radius=NaN;
        if voxel_count==0
            feature_ids=zeros(1,0);
        else
            center_pos=unq_pos(k,:)';

            % - in case of a variable radius, keep growing sphere_offsets
            %   until there are enough voxels selected. This new radius is
            %   kept for every subsequent iteration.
            % - in case of a fixed radius this loop is left after the first
            %   iteration.
            while true
                % add offsets to center
                all_around_pos=bsxfun(@plus, center_pos', sphere_offsets);

                % see which ones are outside the volume
                outside_msk=all_around_pos<=0 | ...
                                bsxfun(@minus,grid_dim,all_around_pos)<0;

                % collapse over 3 dimensions
                feature_outside_msk=any(outside_msk,2);

                % get rid of those outside the volume
                around_pos=all_around_pos(~feature_outside_msk,:);


                % convert to linear indices
                around_lin=fast_sub2ind(grid_dim,around_pos(:,1), ...
                                            around_pos(:,2), ...
                                            around_pos(:,3));

                % convert linear to feature ids
                % (transpose is necessary so that when applying the
                %  mask, the indices remain sorted by distance)
                around_ids_mat=lin2feature_ids(around_lin,:)';
                around_ids_mask=lin2feature_mask(around_lin,:)';
                feature_ids=around_ids_mat(around_ids_mask);

                if use_fixed_radius
                    break; % we're done selecting voxels
                elseif numel(feature_ids)<voxel_count
                    % the current radius is too small.
                    % increase the radius by half a voxel and recompute new
                    % offsets, then try again in the next iteration.
                    radius=radius+.5;
                    [sphere_offsets,distances]=get_sphere_offsets(radius);
                    continue
                end


                % if using variable radius, compute distance of each linear
                % index
                center_distances=distances(~feature_outside_msk);

                % get distance for each feature
                feature_distances=get_distances(center_distances,...
                                            feature_id_count(around_lin));

                % coming here, the radius is variable and enough features
                % were selected. Now decide which voxels to keep,
                % and also compute the metric radius, then leave the while
                % loop.

                nselect=boundary_at_approx(feature_ids,...
                                            feature_distances,voxel_count);
                feature_ids=feature_ids(1:nselect);

                variable_radius=feature_distances(nselect);
                break; % we're done
            end
        end


        % store results
        id=center_idxs{k}(1);


        neighbors{k}=feature_ids(:)';
        nvoxels(k)=numel(feature_ids);
        if use_fixed_radius
            final_radius(k)=radius;
        else
            final_radius(k)=variable_radius;
        end

        visited(k)=true;
        assert(center_ids(k)==0);
        center_ids(k)=id;

        if show_progress && (k==1 || k==nunq_centers || ...
                                        mod(k,opt.progress)==0)
            mean_size=mean(nvoxels(visited));
            msg=sprintf('mean size %.1f', mean_size);
            prev_progress_msg=cosmo_show_progress(clock_start, ...
                                   k/nunq_centers, msg, prev_progress_msg);
        end
    end

    not_visited_ids=find(~visited);
    assert(all(cellfun(@numel,neighbors(not_visited_ids))==0));
    neighbors(not_visited_ids)=repmat({zeros(1,0)},...
                                        1,numel(not_visited_ids));


    % set the dataset and feature attributes
    nbrhood=struct();
    nbrhood.a=ds.a;
    nbrhood.a.fdim=fdim;

    % remove sample dimension if present
    if isfield(nbrhood.a,'sdim')
        nbrhood.a=rmfield(nbrhood.a,'sdim');
    end


    fa_full=cosmo_slice(fa,center_ids,2,'struct');
    nbrhood.fa=cosmo_structjoin('nvoxels',nvoxels,...
                                'radius',final_radius,...
                                'center_ids',center_ids(:)',...
                                fa_full);

    nbrhood.neighbors=neighbors;

    nbrhood=align_nbrhood_to_ds_if_possible(ds,nbrhood);
    origin=struct();
    origin.a=ds.a;
    origin.fa=ds.fa;
    nbrhood.origin=origin;

    cosmo_check_neighborhood(nbrhood,ds);



function nbrhood=align_nbrhood_to_ds_if_possible(ds,nbrhood)
   labels=get_dim_label(ds);

   ds_fa=get_spherical_fa_cell(ds.fa,labels);
   nbrhood_fa=get_spherical_fa_cell(nbrhood.fa,labels);

   [unq_ds,idx_ds]=cosmo_index_unique(ds_fa);
   [unq_nbrhood,idx_nbrhood]=cosmo_index_unique(nbrhood_fa);

   if all(cellfun(@numel,unq_ds)==1) && ...
           isequal(sort(cell2mat(unq_ds)),sort(cell2mat(unq_nbrhood)))
       mp=cosmo_align(nbrhood_fa,ds_fa);

       nbrhood.neighbors=nbrhood.neighbors(mp);
       nbrhood.fa=cosmo_slice(nbrhood.fa,mp,2,'struct');
   end


function feature_distances=get_distances(center_distances,feature_id_count)
    % get distances based on selected features
    n=numel(center_distances);
    assert(n==numel(feature_id_count));

    m=max(feature_id_count);
    if m<=1
        % optimization
        feature_distances=center_distances(feature_id_count==1);
        return
    end

    ds=NaN(m,n);

    for k=1:m
        msk=feature_id_count>=k;
        ds(k,msk)=center_distances(msk);
    end

    feature_distances=ds(~isnan(ds));


function [lin2feature_ids,lin2feature_mask]=get_lin2feature_ids(...
                                            grid_dim,all_pos,center_mask)
    % returns a function that maps linear ids to feature ids
    % the function takes as input linear ids and the distance for each
    % linear id, and returns the feature ids and their corresponding
    % distances

    orig_nvoxels=prod(grid_dim);

    ijk=all_pos(:,center_mask);

    lin_ids=fast_sub2ind(grid_dim, ijk(1,:), ijk(2,:), ijk(3,:));
    [idxs,unq_lin_ids]=cosmo_index_unique(lin_ids');

    mask2full=find(center_mask);
    % lin2feature_ids{k}={i1,...,iN} means that the linear voxel index k
    % corresponds to features i1,...iN
    lin2feature_ids_cell=cell(orig_nvoxels,1);
    for k=1:numel(unq_lin_ids)
        lin_id=unq_lin_ids(k);
        idx=idxs{k}(:)';
        lin2feature_ids_cell{lin_id}=mask2full(idx);
    end

    n_max=max(cellfun(@numel,lin2feature_ids_cell));

    lin2feature_ids=zeros(orig_nvoxels, n_max);
    lin2feature_mask=false(orig_nvoxels,n_max);

    for k=1:numel(unq_lin_ids)
        lin_id=unq_lin_ids(k);
        indices=lin2feature_ids_cell{lin_id};

        cols=1:numel(indices);
        lin2feature_ids(lin_id,cols)=indices;
        lin2feature_mask(lin_id,cols)=true;
    end






function feature_mask=get_features_mask(ds)
    % use .fa.inside if it is present, otherwise an array with only true
    % values
    nfeatures=size(ds.samples,2);

    if cosmo_isfield(ds,'fa.inside')
        inside=ds.fa.inside;

        if size(inside,1)~=1
            error('field .fa.inside must be a row vector');
        end

        if ~islogical(inside)
            error('field .fa.inside must be logical');
        end

        feature_mask=inside;
    else
        feature_mask=true(1,nfeatures);
    end


function lin=fast_sub2ind(sz, i, j, k)
    lin=sz(1)*(sz(2)*(k-1)+(j-1))+i;

function pos=boundary_at_approx(ids, distances, voxel_count)
    % pseudo-random selection of approximatly voxel_count elements
    if voxel_count<=0
        pos=0;
        return
    end

    assert(issorted(distances));

    max_distance=distances(voxel_count);
    first=find(distances<max_distance,1,'last')+1;
    last=find(distances>max_distance,1,'first')-1;

    if isempty(first)
        first=1;
    end

    if isempty(last)
        last=numel(distances);
    end

    delta_first=voxel_count-first;
    delta_last=last-voxel_count;

    if delta_first==delta_last
        % select pseudo-randomly
        if delta_first==0 || mod(sum(ids)+numel(distances),2)==0
            pos=first;
        else
            pos=last;
        end
    elseif delta_first<delta_last
        pos=first;
    else
        pos=last;
    end

    assert(first==1 || distances(first-1)<distances(first));
    assert(last==numel(distances) || distances(last+1)>distances(first));


function [fdim,fa,ijk,orig_dim]=get_spherical_attributes(ds, center_mask)
    % returns fdim, fa, and ijk positions for dataset
    labels=get_dim_label(ds);

    fdim=get_spherical_fdim(ds,labels);
    fa=get_spherical_fa(ds.fa,labels);

    if cosmo_isfield(ds,'fa.inside')
        fa.inside=center_mask;
    end

    small_ds=cosmo_slice(ds,[],1);
    small_ds_vol=cosmo_vol_grid_convert(small_ds, 'tovol');

    ijk=[small_ds_vol.fa.i;small_ds_vol.fa.j;small_ds_vol.fa.k];

    ijk_labels={'i','j','k'};
    [unused,index]=has_fdim_label(small_ds_vol,ijk_labels);

    orig_dim=cellfun(@numel,small_ds_vol.a.fdim.values(index));
    orig_dim=orig_dim(:)';


function [tf,index]=has_fdim_label(ds, label)
    [two,index]=cosmo_dim_find(ds,label,false);
    tf=~isempty(two) && two==2;


function [labels,index]=get_dim_label(ds)
    % get either pos or i, j, and k labels
    possible_labels={{'pos'},{'i';'j';'k'}};
    for j=1:numel(possible_labels)
        labels=possible_labels{j};
        [has_label,index]=has_fdim_label(ds, labels);
        if has_label
            return
        end
    end

    error(['Unable to find dimension labels, either ''pos'' '...
                    'or ''i'', ''j'', and ''k''']);


function fdim=get_spherical_fdim(ds, target_labels)
    first_target_label=target_labels{1};
    [two, index]=cosmo_dim_find(ds,first_target_label,true);

    if two~=2
        error('dimension ''%s'' must be a feature dimension');
    end
    cosmo_isfield(ds,'a.fdim.labels',true);

    dim_labels=ds.a.fdim.labels(:);
    dim_values=ds.a.fdim.values(:);

    nlabels=numel(target_labels);
    idx_labels=(index+(0:(nlabels-1)))';
    if numel(dim_labels)<index+(nlabels-1) || ...
            ~isequal(dim_labels(idx_labels),target_labels)
        error('expected labels %s in .a.fdim.labels(%d:%d)',...
                  cosmo_strjoin(target_labels,', '),...
                idx_labels(1), idx_labels(end));
    end

    fdim=struct();
    fdim.labels=dim_labels(idx_labels);
    fdim.values=dim_values(idx_labels);

    fdim=ensure_row_vector_or_3d_matrix(fdim);

function fdim=ensure_row_vector_or_3d_matrix(fdim)
    labels=fdim.labels;
    nlabels=numel(labels);

    keys={'labels','values'};
    nkeys=numel(keys);
    for k=1:nlabels
        label=labels{k};
        for j=1:nkeys
            key=keys{j};
            value=fdim.(key){k};

            if strcmp(label,'pos') && strcmp(key,'values')
                if size(value,1)~=3
                    error(['''pos'' attribute in .a.fdim.values '...
                                'must be 3xM']);
                end
            else
                if ~isvector(value)
                    error(['''%s'' attribute in .a.fdim.%s must '...
                            'be a vector'],labels,key);
                end
                fdim.(key){k}=value(:)';
            end
        end
    end







function fa=get_spherical_fa(ds_fa, target_labels)
    fa_cell=get_spherical_fa_cell(ds_fa, target_labels);
    fa=cell2struct(fa_cell,target_labels,2);


function fa_cell=get_spherical_fa_cell(ds_fa, target_labels)
    nlabels=numel(target_labels);
    fa_cell=cell(1,nlabels);
    for j=1:nlabels
        label=target_labels{j};
        fa_cell{j}=ds_fa.(label);
    end


function [sphere_offsets, o_distances]=get_sphere_offsets(radius)
    % return offsets and euclidean (and a bit manhattan) distance
    % from origin
    [sphere_offsets, norm2_distances]=cosmo_sphere_offsets(radius);

    % compute manhattan distance
    norm1_distances=sum(abs(sphere_offsets),2);

    % add a tiny bit of manhattan to make distances more varied
    norm12_distances=norm2_distances+1e-5*norm1_distances;

    % ensure distances are sorted
    [o_distances,i]=sort(norm12_distances);
    sphere_offsets=sphere_offsets(i,:);


function check_input(varargin)
    if numel(varargin)<1 || isscalar(varargin{1})
        % change in parameters
        raise_parameter_error();
    end

function [use_fixed_radius,radius,voxel_count]=get_selection_params(opt)
    if isfield(opt,'radius')
        if isfield(opt,'count')
            raise_parameter_error();
        elseif isscalar(opt.radius) && opt.radius>=0
            use_fixed_radius=true;
            radius=opt.radius;
            voxel_count=NaN;
            return
        end
    elseif isfield(opt,'count') && isscalar(opt.count) && ...
                opt.count>=0 && round(opt.count)==opt.count
        use_fixed_radius=false;
        radius=1; % starting point
        voxel_count=opt.count;
        return;
    end

    raise_parameter_error();


function raise_parameter_error()
    name=mfilename();
    error(['Illegal parameters, use one of:\n',...
        '- %s(...,''radius'',r) to use a radius of r voxels\n',...
        '- %s(...,''count'',c) to select c voxels per searchlight\n',...
        '(As of January 2014 the syntax of this function has changed)'],...
            name,name);