cosmo parcellfun skl

function result=cosmo_parcellfun(nproc,func,arg_cell,varargin)
% applies a function to elements in a cell in parallel
%
% result=cosmo_parcellfun(nproc,func,arg_cell,...)
%
% Inputs:
%   nproc                   Maximum number of processes to run in parallel
%   func                    Function handle that takes a single input
%                           argument and gives a single output
%   arg_cell                Cell with arguments to be given to func
%   'UniformOutput',o_u     If false, then the output is converted to a
%                           numeric or logical array. Default: true
%
% Output:
%   result                  Cell with the same size as arg_cell, with
%                               result{i}=func(arg_cell{i})
%                           If o_u is true, then result is converted to a
%                           numeric or logical array, assuming that each
%                           output is a scalar. If any output is not a
%                           scalar while o_u is true, an exception is
%                           thrown.
%

    default=struct();
    default.UniformOutput=true;

    opt=cosmo_structjoin(default,varargin{:});

    check_input(nproc,func,arg_cell,opt);

    % see how many processes to use
    narg_cell=numel(arg_cell);

    if narg_cell==1
        nproc_to_use=1;
    else
        nproc_available=cosmo_parallel_get_nproc_available(opt);
        nproc_to_use=min([nproc,narg_cell,nproc_available]);
    end

    if nproc_to_use<=1
        helper_func=@run_single_thread;
    else
        is_matlab=cosmo_wtf('is_matlab');
        if is_matlab
            helper_func=@run_parallel_matlab;
        else
            helper_func=@run_parallel_octave;
        end
    end

    result=helper_func(nproc_to_use,func,arg_cell,opt);


function args=get_extra_builtin_cellfun_args(opt)
    if opt.UniformOutput
        args={};
    else
        args={'UniformOutput',false};
    end


function result=run_single_thread(nproc,func,arg_cell,opt)
% single thread, Matlab of Octave --- redirect to cellfun
    result_cell=cellfun(func,arg_cell,'UniformOutput',false);
    result=convert_to_uniform_output_if_necessary(result_cell,opt);


function result=run_parallel_matlab(nproc,func,arg_cell,opt)
% multi-thread, Matlab
    narg_cell=numel(arg_cell);
    result_cell=cell(size(arg_cell));

    parfor (i=1:narg_cell, nproc)
        result_cell{i}=func(arg_cell{i});
    end

    result=convert_to_uniform_output_if_necessary(result_cell,opt);

function result=convert_to_uniform_output_if_necessary(result_cell,opt)
    if opt.UniformOutput
        is_uniform_func=@(x)isequal(size(x),[1 1]);
        is_uniform_mask=cellfun(is_uniform_func, result_cell);
        if ~all(is_uniform_mask)
            error(['Not all outputs are scalar. Use\n'...
                    '    ''UniformOutput'',false\n'...
                    'to return cell output']);
        end

        % concatenate and put in original shape
        result=reshape(cat(1,result_cell{:}),size(result_cell));
    else
        result=result_cell;
    end

function result=run_parallel_octave(nproc,func,arg_cell,opt)
% multi-thread, Octave
    extra_octave_args={'VerboseLevel',0};

    cellfun_args=get_extra_builtin_cellfun_args(opt);
    result=parcellfun(nproc,func,arg_cell,...
                                cellfun_args{:},...
                                extra_octave_args{:});



function check_input(nproc,func,arg_cell,opt)
    if ~(isnumeric(nproc) && ...
            isscalar(nproc) && ...
            round(nproc)==nproc && ...
            nproc>0)
        error(['nproc must be a positive integer. Use nproc=inf to use '...
                'as many processes as there are cores available']);
    end

    if ~isa(func,'function_handle')
        error('second argument must be a function handle');
    end

    if ~iscell(arg_cell)
        error('third argument must be a cell');
    end

    if ~(islogical(opt.UniformOutput) && ...
                isscalar(opt.UniformOutput))
        error('option ''UniformOutput'' must be scalar boolean');
    end