cosmo searchlight

function results_map = cosmo_searchlight(ds, nbrhood, measure, varargin)
    %  Generic searchlight function returns a map of results computed at each
    %  searchlight location
    %
    %   results_map=cosmo_searchlight(dataset, nbrhood, measure, ...)
    %
    % Inputs:
    %   ds                   dataset struct with field .samples (NSxNF)
    %   nbrhood              Neighborhood structure with fields:
    %         .a               struct with dataset attributes
    %         .fa              struct with feature attributes. Each field
    %                            should have NF values in the second dimension
    %         .neighbors       cell with NF mappings from center_ids in output
    %                        dataset to feature ids in input dataset.
    %                        Suitable neighborhood structs can be generated
    %                        using:
    %                        - cosmo_spherical_neighborhood (fmri volume)
    %                        - cosmo_surficial_neighborhood (fmri surface)
    %                        - cosmo_meeg_chan_neigborhood (MEEG channels)
    %                        - cosmo_interval_neighborhood (MEEG time, freq)
    %                        - cosmo_cross_neighborhood (to cross neighborhoods
    %                                                    from the neighborhood
    %                                                    functions above)
    %   measure              function handle to a dataset measure. A dataset
    %                        measure has the function signature:
    %                          output = measure(dataset, args)
    %                        where output must be a struct with fields .samples
    %                        (as a column vector) and optionally a field .sa
    %                        with sample attributes.
    %                        Typical measures are:
    %                        - cosmo_correlation_measure
    %                        - cosmo_crossvalidation_measure
    %                        - cosmo_target_dsm_corr_measure
    %   'center_ids', ids    vector indicating center ids to be used as a
    %                        searchlight center. By default all feature ids are
    %                        used (i.e. ids=1:numel(nbrhood.neighbors). The
    %                        output results_map.samples has size N in the 2nd
    %                        dimension.
    %   'progress', p        Show progress every p steps
    %   'nproc', np          If the Matlab parallel processing toolbox, or the
    %                        GNU Octave parallel package is available, use
    %                        np parallel threads. (Multiple threads may speed
    %                        up searchlight computations).
    %                        If parallel processing is not available, or if
    %                        this option is not provided, then a single thread
    %                        is used.
    %   K, V                 any key-value pair (K,V) with arguments for the
    %                        measure function handle. Alternatively a struct
    %                        can be used
    %
    % Output:
    %   results_map          a dataset struct where the samples
    %                        contain the results of the searchlight analysis.
    %                        If measure returns datasets all of size Nx1 and
    %                        there are M center_ids
    %                        (M=numel(nbrhood.neighbors) if center_ids is not
    %                        provided), then results_map.samples has size MxN.
    %                        If nbrhood has fields .a and .fa, these are part
    %                        of the output (with .fa sliced according to
    %                        center_ids)
    %
    % Example:
    %     % use a minimal dataset with 6 voxels
    %     ds=cosmo_synthetic_dataset('nchunks',5);
    %     %
    %     % define neighborhood (progress is set to false to suppress output)
    %     radius=1; % radius=3 is typical for fMRI datasets
    %     nbrhood=cosmo_spherical_neighborhood(ds,'radius',radius,...
    %                                               'progress',false);
    %     %
    %     % define measure and its arguments; here crossvalidation with LDA
    %     % classifier to compute classification accuracies
    %     args=struct();
    %     args.classifier = @cosmo_classify_lda;
    %     args.partitions = cosmo_nfold_partitioner(ds);
    %     measure=@cosmo_crossvalidation_measure;
    %     %
    %     % run searchlight (without showing progress bar)
    %     result=cosmo_searchlight(ds,nbrhood,measure,'progress',0,args);
    %     %
    %     % show results:
    %     % - .samples contains classification accuracy
    %     % - .fa.nvoxels is the number of voxels in each searchlight
    %     % - .fa.radius is the radius of each searchlight
    %     cosmo_disp(result)
    %     %|| .a
    %     %||   .fdim
    %     %||     .labels
    %     %||       { 'i'  'j'  'k' }
    %     %||     .values
    %     %||       { [ 1         2         3 ]  [ 1         2 ]  [ 1 ] }
    %     %||   .vol
    %     %||     .mat
    %     %||       [ 2         0         0        -3
    %     %||         0         2         0        -3
    %     %||         0         0         2        -3
    %     %||         0         0         0         1 ]
    %     %||     .dim
    %     %||       [ 3         2         1 ]
    %     %||     .xform
    %     %||       'scanner_anat'
    %     %|| .fa
    %     %||   .nvoxels
    %     %||     [ 3         4         3         3         4         3 ]
    %     %||   .radius
    %     %||     [ 1         1         1         1         1         1 ]
    %     %||   .center_ids
    %     %||     [ 1         2         3         4         5         6 ]
    %     %||   .i
    %     %||     [ 1         2         3         1         2         3 ]
    %     %||   .j
    %     %||     [ 1         1         1         2         2         2 ]
    %     %||   .k
    %     %||     [ 1         1         1         1         1         1 ]
    %     %|| .samples
    %     %||   [ 1         1         1       0.9         1       0.7 ]
    %     %|| .sa
    %     %||   .labels
    %     %||     { 'accuracy' }
    %
    % Notes:
    %   - neighborhoods can be defined using one or more of the
    %     cosmo_*_neighborhood functions
    %
    % See also: cosmo_correlation_measure,
    %           cosmo_crossvalidation_measure,
    %           cosmo_dissimilarity_matrix_measure,
    %           cosmo_spherical_neighborhood,cosmo_surficial_neighborhood,
    %           cosmo_meeg_chan_neigborhood, cosmo_interval_neighborhood
    %           cosmo_cross_neighborhood
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    sl_defaults = struct();
    sl_defaults.center_ids = [];
    sl_defaults.progress = 1 / 50;
    sl_defaults.nproc = 1;

    % get options for the searchlight function
    sl_opt = cosmo_structjoin(sl_defaults, varargin);
    check_input(ds, nbrhood, measure, sl_opt);

    % get options for the measure. These are all additional arguments,
    % except that progress is set to false and center_ids is removed.
    measure_opt = rmfield(sl_opt, fieldnames(sl_defaults));
    measure_opt.progress = false;

    % get the neighborhood information. This is a cell where
    % neighbors{k} contains the feature indices in input dataset 'ds'
    % for the 'k'-th center of the output dataset
    neighbors = nbrhood.neighbors;

    % get center ids
    center_ids = sl_opt.center_ids;
    if isempty(center_ids)
        center_ids = 1:numel(neighbors); % all output features
    end

    % get number of processes for searchlight
    nproc_available = cosmo_parallel_get_nproc_available(sl_opt);

    % split neighborhood in multiple parts, so that each thread can do a
    % subset of all the work
    nbrhood_cell = split_nbrhood_for_workers(nbrhood, center_ids, ...
                                             nproc_available);

    % if we have more processes than parts, run on limited threads
    nproc_used = min(numel(nbrhood_cell), nproc_available);

    % Matlab needs newline character at progress message to show it in
    % parallel mode; Octave should not have newline character
    environment = cosmo_wtf('environment');
    progress_suffix = get_progress_suffix(environment);

    % set options for each worker process
    worker_opt_cell = cell(1, nproc_used);
    for p = 1:nproc_used
        worker_opt = struct();
        worker_opt.ds = ds;
        worker_opt.measure = measure;
        worker_opt.measure_opt = measure_opt;
        worker_opt.worker_id = p;
        worker_opt.nworkers = nproc_used;
        worker_opt.progress = sl_opt.progress;
        worker_opt.progress_suffix = progress_suffix;
        worker_opt.nbrhood = nbrhood_cell{p};

        worker_opt_cell{p} = worker_opt;
    end

    % Run process for each worker in parallel
    % Note that when using nproc=1, cosmo_parcellfun does actually not
    % use any parallelization; the result is a cell with a single element.
    result_map_cell = cosmo_parcellfun(sl_opt.nproc, ...
                                       @run_searchlight_with_worker, ...
                                       worker_opt_cell, ...
                                       'UniformOutput', false);

    results_map = cosmo_stack(result_map_cell, 2);
    cosmo_check_dataset(results_map);

function results_map = run_searchlight_with_worker(worker_opt)
    % run searchlight using the options in worker_opt
    ds = worker_opt.ds;
    nbrhood = worker_opt.nbrhood;
    measure = worker_opt.measure;
    measure_opt = worker_opt.measure_opt;
    worker_id = worker_opt.worker_id;
    nworkers = worker_opt.nworkers;
    progress = worker_opt.progress;
    progress_suffix = worker_opt.progress_suffix;

    neighbors = nbrhood.neighbors;

    % allocate space for output. res_cell contains the output
    % of the measure applied to each group of features defined in
    % nbrhood. Afterwards the elements in res_cell are combined.
    ncenters = numel(nbrhood.neighbors);
    res_cell = cell(ncenters, 1);

    % see if progress is to be reported
    show_progress = ~isempty(progress) && ...
                        progress && ...
                        worker_id == 1;
    if show_progress
        progress_step = progress;
        if progress_step < 1
            progress_step = ceil(ncenters * progress_step);
        end
        prev_progress_msg = '';
        clock_start = clock();
    end

    % if measure gave the wrong result one wants to know sooner rather than
    % later. here only the first result is checked. (other errors may only
    % be caught after this 'for'-loop)
    % this is a compromise between execution speed and error reporting.
    checked_first_output = false;

    % Core searchlight code.
    % For each center_id:
    % - get the indices of its neighbors
    % - slice the dataset "ds" using these indices
    % - apply the measure to this sliced dataset with its arguments "args"
    % - store the result in "res"
    %
    for center_id = 1:ncenters
        neighbor_feature_ids = neighbors{center_id};

        % slice the dataset (with disabled kosherness-check for every
        % but the first neighborhood)
        sphere_ds = cosmo_slice(ds, neighbor_feature_ids, 2, ...
                                ~checked_first_output);

        % apply the measure
        % (try/catch can be used to provide an error message indicating
        %  which feature id caused an error)
        try
            res = measure(sphere_ds, measure_opt);

            % for efficiency, only check first output
            if ~checked_first_output
                checked_first_output = true;

                % optimization to switch off checking the partitions,
                % because they don't change for different searchlights
                measure_opt.check_partitions = false;

                cosmo_check_dataset(res);
                if size(res.samples, 2) ~= 1
                    error('Measure output must yield a column vector');
                end
            end
        catch
            caught_error = lasterror();
            caught_error.message = sprintf(['In %s, center id %d caused '...
                                            'an exception:\n%s'], ...
                                           mfilename(), ...
                                           center_id, ...
                                           caught_error.message);
            rethrow(caught_error);
        end

        res_cell{center_id} = res;

        % show progress
        if show_progress && (center_id < 10 || ...
                             ~mod(center_id, progress_step) || ...
                             center_id == ncenters)
            if nworkers > 1
                if center_id == ncenters
                    % other workers may be slower than first worker
                    msg = sprintf(['worker %d has completed; waiting for '...
                                   'other workers to finish...%s'], ...
                                  worker_id, progress_suffix);
                else
                    % can only show progress from a single worker;
                    % therefore show progress of first worker
                    msg = sprintf('for worker %d / %d%s', worker_id, ...
                                  nworkers, progress_suffix);
                end
            else
                % no specific message
                msg = '';
            end
            prev_progress_msg = cosmo_show_progress(clock_start, ...
                                                    center_id / ncenters, msg, prev_progress_msg);
        end
    end

    % prepare the output
    results_map = struct();

    % set dataset and feature attributes
    results_map.a = nbrhood.a;

    % slice the feature attributes
    results_map.fa = nbrhood.fa;

    % join the outputs from the measure for each searchlight position
    res_stacked = cosmo_stack(res_cell, 2);
    results_map.samples = res_stacked.samples;

    % if measure returns .sa, add those.
    if isfield(res_stacked, 'sa')
        results_map.sa = res_stacked.sa;
    end

    % if it returns sample attribute dimensions, add those
    if cosmo_isfield(res_stacked, 'a.sdim')
        results_map.a.sdim = res_stacked.a.sdim;
    end

function nbrhood_cell = split_nbrhood_for_workers(nbrhood, center_ids, nproc)
    % splits the neighborhood in multiple smaller neighborhoods that can be
    % used in parallel
    ncenters = numel(center_ids);

    block_size = ceil(ncenters / nproc);
    nproc_used = ceil(ncenters / block_size);
    nbrhood_cell = cell(nproc_used, 1);

    first = 1;
    for block = 1:nproc_used
        last = min(first + block_size - 1, ncenters);
        block_idxs = first:last;

        block_center_ids = center_ids(block_idxs);

        block_nbrhood = struct();
        block_nbrhood.neighbors = nbrhood.neighbors(block_center_ids);
        block_nbrhood.a = nbrhood.a;
        block_nbrhood.fa = cosmo_slice(nbrhood.fa, block_center_ids, 2, ...
                                       'struct');
        block_nbrhood.fa.center_ids = block_center_ids(:)';

        nbrhood_cell{block} = block_nbrhood;

        first = last + 1;
    end

function check_input(ds, nbrhood, measure, opt)
    if isa(nbrhood, 'function_handle') || ...
                isfield(opt, 'args') || ...
                ~isa(measure, 'function_handle')
        raise_parameter_exception();
    end

    nproc = opt.nproc;
    if ~(isnumeric(nproc) && ...
         isscalar(nproc) && ...
         round(nproc) == nproc && ...
         nproc >= 1)
        error('nproc must be positive scalar');
    end

    cosmo_check_dataset(ds);
    cosmo_check_neighborhood(nbrhood, ds);

function raise_parameter_exception()
    error(['Illegal syntax, use:\n\n', ...
           '  %s(ds,nbrhood,measure,...)\n\n', ...
           'where \n', ...
           '- ds is a dataset struct\n', ...
           '- nbrhood is a neighborhood struct\n', ...
           '- measure is a function handle of a dataset measure\n', ...
           '- any arguments to measure can be given at the ''...''\n', ...
           '  position, or as a struct\n', ...
           '(Note: as of January 2015 the syntax for this function\n'...
           'has changed. The neighborhood argument is now a fixed\n'...
           'parameter, and measure arguments are passed directly\n'...
           'rather than through an ''args'' arguments'], ...
          mfilename());

function suffix = get_progress_suffix(environment)
    % Matlab needs newline character at progress message to show it in
    % parallel mode; Octave should not have newline character

    switch environment
        case 'matlab'
            suffix = sprintf('\n');
        case 'octave'
            suffix = '';
    end