cosmo searchlight skl

function results_map = cosmo_searchlight(ds, nbrhood, measure, varargin)
%  Generic searchlight function; returns a map of results computed at each
%  searchlight location
%
%   results_map=cosmo_searchlight(dataset, nbrhood, measure, ...)
%
% Inputs:
%   ds                   dataset struct with field .samples (NSxNF)
%   nbrhood              Neighborhood structure with fields:
%         .a               struct with dataset attributes
%         .fa              struct with feature attributes. Each field
%                            should have NC values in the second dimension,
%                            with NC the number of searchlight centers
%         .neighbors       cell with NC elements; neighbors{k} contains the
%                        feature ids in the input dataset that make up the
%                        searchlight for the k-th center in the output
%                        dataset.
%                        Suitable neighborhood structs can be generated
%                        using:
%                        - cosmo_spherical_neighborhood (fmri volume)
%                        - cosmo_surficial_neighborhood (fmri surface)
%                        - cosmo_meeg_chan_neighborhood (MEEG channels)
%                        - cosmo_interval_neighborhood (MEEG time, freq)
%                        - cosmo_cross_neighborhood (to cross neighborhoods
%                                                    from the neighborhood
%                                                    functions above)
%   measure              function handle to a dataset measure. A dataset
%                        measure has the function signature:
%                          output = measure(dataset, args)
%                        where output must be a struct with fields .samples
%                        (as a column vector) and optionally a field .sa
%                        with sample attributes.
%                        Typical measures are:
%                        - cosmo_correlation_measure
%                        - cosmo_crossvalidation_measure
%                        - cosmo_target_dsm_corr_measure
%   'center_ids', ids    vector indicating center ids to be used as a
%                        searchlight center. By default all feature ids are
%                        used (i.e. ids=1:numel(nbrhood.neighbors)). The
%                        output results_map.samples has numel(ids) values
%                        in the 2nd dimension.
%   'progress', p        Show progress every p steps. If 0<p<1, progress
%                        is shown every ceil(p*NC) steps, with NC the
%                        number of centers processed by the first worker.
%                        Use p=0 or false to suppress progress output.
%                        Default: 1/50.
%   'nproc', np          If the Matlab parallel processing toolbox, or the
%                        GNU Octave parallel package is available, use
%                        np parallel threads. (Multiple threads may speed
%                        up searchlight computations).
%                        If parallel processing is not available, or if
%                        this option is not provided, then a single thread
%                        is used. Default: 1.
%   K, V                 any key-value pair (K,V) with arguments for the
%                        measure function handle. Alternatively a struct
%                        can be used
%
% Output:
%   results_map          a dataset struct where the samples
%                        contain the results of the searchlight analysis.
%                        If measure returns datasets all of size Nx1 and
%                        there are M center_ids
%                        (M=numel(nbrhood.neighbors) if center_ids is not
%                        provided), then results_map.samples has size NxM.
%                        If nbrhood has fields .a and .fa, these are part
%                        of the output (with .fa sliced according to
%                        center_ids)
%
% Example:
%     % use a minimal dataset with 6 voxels
%     ds=cosmo_synthetic_dataset('nchunks',5);
%     %
%     % define neighborhood (progress is set to false to suppress output)
%     radius=1; % radius=3 is typical for fMRI datasets
%     nbrhood=cosmo_spherical_neighborhood(ds,'radius',radius,...
%                                               'progress',false);
%     %
%     % define measure and its arguments; here crossvalidation with LDA
%     % classifier to compute classification accuracies
%     args=struct();
%     args.classifier = @cosmo_classify_lda;
%     args.partitions = cosmo_nfold_partitioner(ds);
%     measure=@cosmo_crossvalidation_measure;
%     %
%     % run searchlight (without showing progress bar)
%     result=cosmo_searchlight(ds,nbrhood,measure,'progress',0,args);
%     %
%     % show results:
%     % - .samples contains classification accuracy
%     % - .fa.nvoxels is the number of voxels in each searchlight
%     % - .fa.radius is the radius of each searchlight
%     cosmo_disp(result)
%     %|| .a
%     %||   .fdim
%     %||     .labels
%     %||       { 'i'  'j'  'k' }
%     %||     .values
%     %||       { [ 1         2         3 ]  [ 1         2 ]  [ 1 ] }
%     %||   .vol
%     %||     .mat
%     %||       [ 2         0         0        -3
%     %||         0         2         0        -3
%     %||         0         0         2        -3
%     %||         0         0         0         1 ]
%     %||     .dim
%     %||       [ 3         2         1 ]
%     %||     .xform
%     %||       'scanner_anat'
%     %|| .fa
%     %||   .nvoxels
%     %||     [ 3         4         3         3         4         3 ]
%     %||   .radius
%     %||     [ 1         1         1         1         1         1 ]
%     %||   .center_ids
%     %||     [ 1         2         3         4         5         6 ]
%     %||   .i
%     %||     [ 1         2         3         1         2         3 ]
%     %||   .j
%     %||     [ 1         1         1         2         2         2 ]
%     %||   .k
%     %||     [ 1         1         1         1         1         1 ]
%     %|| .samples
%     %||   [ 1         1         1       0.9         1       0.7 ]
%     %|| .sa
%     %||   .labels
%     %||     { 'accuracy' }
%
% Notes:
%   - neighborhoods can be defined using one or more of the
%     cosmo_*_neighborhood functions
%
% See also: cosmo_correlation_measure,
%           cosmo_crossvalidation_measure,
%           cosmo_dissimilarity_matrix_measure,
%           cosmo_spherical_neighborhood,cosmo_surficial_neighborhood,
%           cosmo_meeg_chan_neighborhood, cosmo_interval_neighborhood
%           cosmo_cross_neighborhood
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    sl_defaults=struct();
    sl_defaults.center_ids=[];
    sl_defaults.progress=1/50;
    sl_defaults.nproc=1;

    % get options for the searchlight function
    sl_opt=cosmo_structjoin(sl_defaults,varargin);
    check_input(ds,nbrhood,measure,sl_opt);

    % get options for the measure. These are all additional arguments;
    % the searchlight-specific options (center_ids, progress, nproc) are
    % removed, and progress is then set to false for the measure
    % (presumably so the measure does not report its own progress).
    measure_opt=rmfield(sl_opt,fieldnames(sl_defaults));
    measure_opt.progress=false;

    % get the neighborhood information. This is a cell where
    % neighbors{k} contains the feature indices in input dataset 'ds'
    % for the 'k'-th center of the output dataset
    neighbors=nbrhood.neighbors;

    % get center ids
    center_ids=sl_opt.center_ids;
    if isempty(center_ids)
        center_ids=1:numel(neighbors); % all output features
    end

    % get number of processes for searchlight
    nproc_available=cosmo_parallel_get_nproc_available(sl_opt);

    % split neighborhood in multiple parts, so that each thread can do a
    % subset of all the work
    nbrhood_cell=split_nbrhood_for_workers(nbrhood,center_ids,...
                                                    nproc_available);

    % if we have more processes than parts, run on limited threads
    nproc_used=min(numel(nbrhood_cell),nproc_available);

    % Matlab needs newline character at progress message to show it in
    % parallel mode; Octave should not have newline character
    environment=cosmo_wtf('environment');
    progress_suffix=get_progress_suffix(environment);

    % set options for each worker process
    worker_opt_cell=cell(1,nproc_used);
    for p=1:nproc_used
        worker_opt=struct();
        worker_opt.ds=ds;
        worker_opt.measure=measure;
        worker_opt.measure_opt=measure_opt;
        worker_opt.worker_id=p;
        worker_opt.nworkers=nproc_used;
        worker_opt.progress=sl_opt.progress;
        worker_opt.progress_suffix=progress_suffix;
        worker_opt.nbrhood=nbrhood_cell{p};

        worker_opt_cell{p}=worker_opt;
    end

    % Run process for each worker in parallel.
    % The number of processes is nproc_used (not the requested
    % sl_opt.nproc): it accounts for parallel availability and for the
    % number of neighborhood parts. When it equals 1, cosmo_parcellfun
    % does not actually use any parallelization; the result is a cell
    % with a single element.
    result_map_cell=cosmo_parcellfun(nproc_used,...
                                     @run_searchlight_with_worker,...
                                     worker_opt_cell,...
                                     'UniformOutput',false);

    % join the partial maps of all workers along the feature dimension
    results_map=cosmo_stack(result_map_cell,2);
    cosmo_check_dataset(results_map);

function results_map=run_searchlight_with_worker(worker_opt)
% Run the searchlight for the subset of centers assigned to one worker.
%
% Input:
%   worker_opt        struct with fields:
%     .ds               input dataset
%     .nbrhood          neighborhood (subset) for this worker, with fields
%                       .neighbors, .a and .fa
%     .measure          function handle of the dataset measure
%     .measure_opt      options passed as second argument to .measure
%     .worker_id        index of this worker (progress is only shown by
%                       worker 1)
%     .nworkers         total number of workers
%     .progress         progress setting (see cosmo_searchlight)
%     .progress_suffix  string appended to progress messages
%
% Output:
%   results_map       dataset with the measure outputs for all centers in
%                     .nbrhood stacked along the 2nd dimension, with .a and
%                     .fa taken from .nbrhood
    ds=worker_opt.ds;
    nbrhood=worker_opt.nbrhood;
    measure=worker_opt.measure;
    measure_opt=worker_opt.measure_opt;
    worker_id=worker_opt.worker_id;
    nworkers=worker_opt.nworkers;
    progress=worker_opt.progress;
    progress_suffix=worker_opt.progress_suffix;

    neighbors=nbrhood.neighbors;

    % allocate space for output. res_cell contains the output
    % of the measure applied to each group of features defined in
    % nbrhood. Afterwards the elements in res_cell are combined.
    ncenters=numel(neighbors);
    res_cell=cell(ncenters,1);

    % see if progress is to be reported
    show_progress=~isempty(progress) && ...
                        progress && ...
                        worker_id==1;
    if show_progress
        progress_step=progress;
        if progress_step<1
            % a fraction of all centers handled by this worker
            progress_step=ceil(ncenters*progress_step);
        end
        prev_progress_msg='';
        clock_start=clock();
    end

    % if measure gave the wrong result one wants to know sooner rather than
    % later. here only the first result is checked. (other errors may only
    % be caught after this 'for'-loop)
    % this is a compromise between execution speed and error reporting.
    checked_first_output=false;

    % Core searchlight code.
    % For each center_id:
    % - get the indices of its neighbors
    % - slice the dataset "ds" using these indices
    % - apply the measure to this sliced dataset with its arguments
    %   "measure_opt"
    % - store the result in "res_cell"
    %
    for center_id=1:ncenters
        % feature ids (in the input dataset) of the current searchlight
        center_feature_ids=neighbors{center_id};

        % slice the dataset. The 4th argument (false) is intended to skip
        % the input checks in cosmo_slice for speed, as ds was already
        % checked by the caller.
        center_ds=cosmo_slice(ds,center_feature_ids,2,false);

        % apply the measure to the features in the searchlight
        res=measure(center_ds,measure_opt);

        if ~checked_first_output
            % the measure must return a valid dataset whose .samples is a
            % column vector, so that outputs can be stacked along the 2nd
            % dimension with one column per center
            cosmo_check_dataset(res);
            if size(res.samples,2)~=1
                error(['measure %s must return a dataset with .samples '...
                            'as a column vector; found size %s'],...
                            func2str(measure), mat2str(size(res.samples)));
            end
            checked_first_output=true;
        end

        res_cell{center_id}=res;

        % show progress
        if show_progress && (center_id<10 || ...
                                ~mod(center_id,progress_step) || ...
                                center_id==ncenters)
            if nworkers>1
                if center_id==ncenters
                    % other workers may be slower than first worker
                    msg=sprintf(['worker %d has completed; waiting for '...
                                    'other workers to finish...%s'],...
                                    worker_id, progress_suffix);
                else
                    % can only show progress from a single worker;
                    % therefore show progress of first worker
                    msg=sprintf('for worker %d / %d%s', worker_id, ...
                                    nworkers, progress_suffix);
                end
            else
                % no specific message
                msg='';
            end
            prev_progress_msg=cosmo_show_progress(clock_start, ...
                            center_id/ncenters, msg, prev_progress_msg);
        end
    end

    % prepare the output
    results_map=struct();

    % set dataset attributes
    results_map.a=nbrhood.a;

    % feature attributes (already sliced to this worker's centers)
    results_map.fa=nbrhood.fa;

    % join the outputs from the measure for each searchlight position
    res_stacked=cosmo_stack(res_cell,2);
    results_map.samples=res_stacked.samples;

    % if measure returns .sa, add those.
    if isfield(res_stacked,'sa')
        results_map.sa=res_stacked.sa;
    end

    % if it returns sample attribute dimensions, add those
    if cosmo_isfield(res_stacked, 'a.sdim')
        results_map.a.sdim=res_stacked.a.sdim;
    end


function nbrhood_cell=split_nbrhood_for_workers(nbrhood,center_ids,nproc)
% Divide center_ids into at most nproc contiguous blocks of
% ceil(numel(center_ids)/nproc) centers each (the last block may be
% smaller), so that each block can be processed by a separate worker.
% Returns a column cell with one sub-neighborhood per block, each having:
%   .neighbors   the neighbors of the block's centers
%   .a           copied from nbrhood.a
%   .fa          nbrhood.fa sliced to the block's centers, with
%                .center_ids set to the block's center ids (row vector)
    ncenters_total=numel(center_ids);
    per_block=ceil(ncenters_total/nproc);
    nblocks=ceil(ncenters_total/per_block);

    nbrhood_cell=cell(nblocks,1);
    for b=1:nblocks
        % index range into center_ids covered by this block
        idx_first=(b-1)*per_block+1;
        idx_last=min(b*per_block,ncenters_total);
        ids=center_ids(idx_first:idx_last);

        sub_nbrhood=struct();
        sub_nbrhood.neighbors=nbrhood.neighbors(ids);
        sub_nbrhood.a=nbrhood.a;
        sub_nbrhood.fa=cosmo_slice(nbrhood.fa,ids,2,'struct');
        sub_nbrhood.fa.center_ids=ids(:)';

        nbrhood_cell{b}=sub_nbrhood;
    end


function check_input(ds, nbrhood, measure, opt)
% Validate the inputs of the searchlight: calling syntax, the nproc
% option, the dataset and the neighborhood. Raises an error if any of
% them is invalid.
    uses_old_syntax=isa(nbrhood,'function_handle') || ...
                        isfield(opt,'args') || ...
                        ~isa(measure,'function_handle');
    if uses_old_syntax
        raise_parameter_exception();
    end

    % nproc must be a positive integer scalar
    np=opt.nproc;
    nproc_is_valid=isnumeric(np) && isscalar(np) && ...
                        round(np)==np && np>=1;
    if ~nproc_is_valid
        error('nproc must be positive scalar');
    end

    cosmo_check_dataset(ds);
    cosmo_check_neighborhood(nbrhood,ds);


function raise_parameter_exception()
% Throw an error that explains the correct calling syntax of this
% function, including a note about the syntax change of January 2015.
% The '%s' placeholder in the message is filled with the file name.
    usage_parts={'Illegal syntax, use:\n\n',...
                 '  %s(ds,nbrhood,measure,...)\n\n',...
                 'where \n',...
                 '- ds is a dataset struct\n',...
                 '- nbrhood is a neighborhood struct\n',...
                 '- measure is a function handle of a dataset measure\n',...
                 '- any arguments to measure can be given at the ''...''\n',...
                 '  position, or as a struct\n',...
                 '(Note: as of January 2015 the syntax for this function\n',...
                 'has changed. The neighboorhood argument is now a fixed\n',...
                 'parameter, and measure arguments are passed directly\n',...
                 'rather than through an ''args'' arguments'};
    usage_fmt=[usage_parts{:}];
    error(usage_fmt, mfilename());


function suffix=get_progress_suffix(environment)
% Return the string to append to progress messages for the given
% environment name ('matlab' or 'octave').
%
% Matlab needs a newline character at the end of a progress message to
% show it in parallel mode; Octave should not have a newline character.
% For any other environment no suffix is used. The suffix is purely
% cosmetic, so this falls back to '' rather than failing with an
% unassigned output argument.
    switch environment
        case 'matlab'
            suffix=sprintf('\n');
        case 'octave'
            suffix='';
        otherwise
            suffix='';
    end