test parallel get nproc available

function test_suite = test_parallel_get_nproc_available()
    % tests for cosmo_parallel_get_nproc_available
    %
    % Entry point of this test file: tries to collect handles to the local
    % functions in this file and then runs the initTestSuite script.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % NOTE(review): test_functions is not referenced again in this
        % function; presumably initTestSuite reads it from this
        % workspace - confirm against the test framework in use
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % deliberately empty: a failing localfunctions() call is ignored
    end
    % NOTE(review): nothing in this function assigns test_suite
    % explicitly, so initTestSuite is presumably responsible for that
    initTestSuite;

function test_parallel_get_nproc_available_matlab_ge2013b()
    % Checks cosmo_parallel_get_nproc_available on Matlab >= 2013b.
    %
    % - without the parallel toolbox (or without 'gcp' on the path) the
    %   expected output is 1, also when two processes are requested;
    % - with the toolbox, requesting as many processes as the pool has
    %   workers is expected to return that number of workers.
    %
    % The test is skipped on other platforms. A pool that is started by
    % this test is deleted again when the test finishes.
    if ~cosmo_wtf('is_matlab') || ...
                version_lt2013b()
        cosmo_notify_test_skipped('Only for the Matlab platform >=2013b');
        return
    end

    has_parallel_toolbox = any(cosmo_check_external( ...
                                                    {'@distcomp', '@parallel'}, false)) && ...
                                ~isempty(which('gcp'));

    % helper: assert that calling the function under test with the
    % arguments after the first one gives the expected output
    aeq = @(expected_output, varargin)assertEqual(expected_output, ...
                                                  cosmo_parallel_get_nproc_available(varargin{:}));

    % silence warnings for the duration of this test; the original
    % warning state is restored when state_resetter goes out of scope
    warning_state = cosmo_warning();
    state_resetter = onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');

    if has_parallel_toolbox
        % Query the current pool without side effects. A plain gcp() call
        % may start a pool automatically (depending on the parallel
        % preferences); such a pool would not be recognized as started by
        % this test and would be left running afterwards.
        pool = gcp('nocreate');

        has_no_pool = isempty(pool);

        if has_no_pool
            % start a new pool, and make sure it is removed afterwards
            pool = parpool();
            pool_resetter = onCleanup(@()delete(gcp('nocreate')));
        end

        if isempty(pool)
            % no pool could be obtained: single process expected
            aeq(1);
            aeq(1, 'nproc', 2);
        else
            nproc = pool.NumWorkers();
            aeq(nproc, 'nproc', nproc);
        end

    else
        aeq(1);
        aeq(1, 'nproc', 2);
    end

function tf = version_lt2013b()
    % Returns true if the version number reported by
    % cosmo_wtf('version_number') is lower than 8.2, false otherwise.
    % (The name of this function indicates that 8.2 corresponds to
    % Matlab 2013b.)
    v_num = cosmo_wtf('version_number');
    major = v_num(1);

    if major ~= 8
        % the major version number alone decides
        tf = major < 8;
    else
        % major version is 8: the minor version number decides
        tf = v_num(2) < 2;
    end

function test_parallel_get_nproc_available_matlab_lt2013b()
    % Checks cosmo_parallel_get_nproc_available on Matlab < 2013b.
    %
    % - without the distributed computing toolbox (or without
    %   'matlabpool' on the path) the expected output is 1, also when
    %   two processes are requested;
    % - with the toolbox, requesting as many processes as the size of
    %   the matlabpool is expected to return that size.
    %
    % The test is skipped on other platforms. A matlabpool that is opened
    % by this test is closed again when the test finishes.
    if ~cosmo_wtf('is_matlab') || ...
                ~version_lt2013b()
        cosmo_notify_test_skipped('Only for the Matlab platform <2013b');
        return
    end

    has_parallel_toolbox = cosmo_check_external('@distcomp', false) && ...
                                ~isempty(which('matlabpool'));

    % helper: assert that calling the function under test with the
    % arguments after the first one gives the expected output
    aeq = @(expected_output, varargin)assertEqual(expected_output, ...
                                                  cosmo_parallel_get_nproc_available(varargin{:}));

    % silence warnings for the duration of this test; the original
    % warning state is restored when state_resetter goes out of scope
    warning_state = cosmo_warning();
    state_resetter = onCleanup(@()cosmo_warning(warning_state));
    cosmo_warning('off');

    if has_parallel_toolbox
        % matlabpool is called through a handle
        func = @matlabpool;
        query_func = @()func('size');
        nproc_available = query_func();

        if nproc_available == 0
            % no pool is open: open one, and make sure that it is closed
            % afterwards so that this test does not leave a pool running
            % (consistent with the test for Matlab >= 2013b)
            func();
            pool_resetter = onCleanup(@()func('close'));
            nproc_available = query_func();
        end

        aeq(nproc_available, 'nproc', nproc_available);
    else
        aeq(1);
        aeq(1, 'nproc', 2);
    end

function test_parallel_get_nproc_available_octave()
    % Checks cosmo_parallel_get_nproc_available on Octave.
    %
    % - without the 'parallel' package the expected output is 1, also
    %   when two processes are requested;
    % - with the package, requesting nproc('overridable') processes is
    %   expected to return that same number.
    %
    % The test is skipped on other platforms.
    if ~cosmo_wtf('is_octave')
        cosmo_notify_test_skipped('Only for the Octave platform');
        return
    end

    has_parallel_toolbox = cosmo_check_external( ...
                                    'octave_pkg_parallel', false);

    % helpers: call the function under test, and compare its output
    % with the expected value
    query = @(varargin)cosmo_parallel_get_nproc_available(varargin{:});
    aeq = @(expected_output, varargin)assertEqual(expected_output, ...
                                                  query(varargin{:}));

    % switch warnings off; the previous state is restored on exit
    saved_state = cosmo_warning();
    state_resetter = onCleanup(@()cosmo_warning(saved_state));
    cosmo_warning('off');

    if ~has_parallel_toolbox
        % no parallel package: a single process is expected
        aeq(1);
        aeq(1, 'nproc', 2);
        return
    end

    nproc_available = nproc('overridable');
    aeq(nproc_available, 'nproc', nproc_available);

function test_parallel_get_nproc_available_override_query_func
    % Checks cosmo_parallel_get_nproc_available with a mock function
    % (passed through the 'nproc_available_query_func' option) that
    % reports 1, 2 or 3 available processes.
    %
    % For each mocked availability it verifies that
    % - without the 'nproc' option the mocked number is returned;
    % - with 'nproc' set, the minimum of the requested and the available
    %   number is returned;
    % - a warning is recorded if and only if the requested number is
    %   finite and exceeds the available number.
    requested_counts = [1:6, inf];

    for navailable = 1:3
        opt = struct('nproc_available_query_func', ...
                     @()mock_query_func(navailable));

        assertEqual(navailable, cosmo_parallel_get_nproc_available(opt));

        for k = 1:numel(requested_counts)
            n_requested = requested_counts(k);

            % remember the warning state; it is restored when
            % warning_resetter is cleared (or on an early exit)
            saved_state = cosmo_warning();
            warning_resetter = onCleanup(@()cosmo_warning(saved_state));

            % NOTE(review): assumes that 'reset' empties the list of
            % shown warnings, so that the check below only sees warnings
            % from the call in this iteration - confirm in cosmo_warning
            cosmo_warning('reset');
            cosmo_warning('off');

            % number of processes must be capped at what is available
            opt.nproc = n_requested;
            n_used = cosmo_parallel_get_nproc_available(opt);
            assertEqual(min(n_requested, navailable), n_used);

            % asking for more than available warrants a warning, unless
            % the request was infinite
            exceeds_available = n_requested > navailable;
            should_show_warning = exceeds_available && ...
                                        isfinite(n_requested);

            warning_info = cosmo_warning();
            did_show_warning = ~isempty(warning_info.shown_warnings);
            assertEqual(did_show_warning, should_show_warning);

            clear warning_resetter;
        end
    end

function test_parallel_get_nproc_available_error_with_bad_gcp
    % On Matlab >= 2013b: runs the shared helper with a 'gcp' function
    % that always raises an error. Skipped on other platforms.
    is_supported = cosmo_wtf('is_matlab') && ~version_lt2013b();

    if ~is_supported
        cosmo_notify_test_skipped('Only for the Matlab platform >=2013b');
        return
    end

    helper_test_with_bad_function('gcp');

function test_parallel_get_nproc_available_error_with_bad_matlabpool
    % On Matlab < 2013b: runs the shared helper with a 'matlabpool'
    % function that always raises an error. Skipped on other platforms.
    is_supported = cosmo_wtf('is_matlab') && version_lt2013b();

    if ~is_supported
        cosmo_notify_test_skipped('Only for the Matlab platform <2013b');
        return
    end

    helper_test_with_bad_function('matlabpool');

function helper_test_with_bad_function(func_name)
    % Writes a function named func_name that always raises an error into
    % a temporary directory, adds that directory to the search path, and
    % verifies that cosmo_parallel_get_nproc_available then returns 1.
    %
    % The temporary file, path entry and directory are removed again
    % when this function exits.
    signature_line = sprintf('function pool=%s()', func_name);
    body_line = 'error(''here'')';
    [pth, nm] = write_func(func_name, {signature_line, body_line});

    % register the cleanup before modifying the search path
    cleaner = onCleanup(@()clean_func(pth, nm));
    addpath(pth);

    assertEqual(cosmo_parallel_get_nproc_available(), 1);

function [pth, nm] = write_func(func_name, lines)
    % Writes lines of text to a file [func_name '.m'] in a newly created
    % temporary directory.
    %
    % Inputs:
    %   func_name   name of the function; the file name is this name
    %               followed by '.m'
    %   lines       cell with strings; each one is written to the file
    %               followed by a newline character
    %
    % Outputs:
    %   pth         path of the temporary directory that was created
    %   nm          name of the file (without directory) in pth
    %
    % An error is raised if the file cannot be opened for writing.
    pth = tempname();

    mkdir(pth);
    nm = [func_name '.m'];

    fn = fullfile(pth, nm);
    [fid, msg] = fopen(fn, 'w');
    if fid < 0
        % fail here with a clear message, instead of through fprintf and
        % fclose being called with an invalid file identifier
        error('Unable to open %s for writing: %s', fn, msg);
    end

    % close the file when this function exits, also if writing fails
    cleaner = onCleanup(@()fclose(fid));
    fprintf(fid, '%s\n', lines{:});

function clean_func(pth, nm)
    % Removes the file nm from directory pth, takes pth off the search
    % path, and finally removes the directory pth itself.
    target_file = fullfile(pth, nm);
    delete(target_file);

    rmpath(pth);
    rmdir(pth);

function [output, msg] = throw_error()
    % Always raises an error with message 'here'; the declared outputs
    % output and msg are never assigned.
    %
    % NOTE(review): no call to this function is visible in this file -
    % confirm whether it is still needed
    error('here');

function [output, msg] = mock_query_func(nproc)
    % Mock for a function that queries the number of available
    % processes: returns the input nproc unchanged as the first output,
    % and an empty string as the second output.
    output = nproc;
    msg = '';

function test_parallel_get_nproc_available_exceptiont()
    % Verifies that cosmo_parallel_get_nproc_available throws an
    % exception for each of a series of illegal argument lists.
    %
    % Warnings are switched off during the test; the previous warning
    % state is restored on exit.
    saved_state = cosmo_warning();
    state_resetter = onCleanup(@()cosmo_warning(saved_state));
    cosmo_warning('off');

    % each element is one complete (illegal) argument list
    illegal_args = {{'foo'}, ...
                    {3}, ...
                    {'nproc', 0}, ...
                    {'nproc', -1}, ...
                    {struct('nproc', 1.6)} ...
                   };

    % iterating over a cell row vector gives a 1x1 cell per iteration
    for wrapped_args = illegal_args
        args = wrapped_args{1};
        call_with_args = @()cosmo_parallel_get_nproc_available(args{:});

        % the empty identifier is passed on unchanged to the assertion
        assertExceptionThrown(call_with_args, '');
    end