cosmo rand

function result=cosmo_rand(varargin)
% generate uniform pseudo-random numbers, optionally using a seed value
%
% result=cosmo_rand(s1,...,sN,['seed',seed],[type])
%
% Input:
%    s*              scalar or vector indicating dimensions of the result.
%                    All size arguments must precede any string arguments.
%                    If no size argument is given, a scalar is returned.
%    'seed', seed    (optional) if provided, use this seed value for
%                    pseudo-random number generation. A seed value of 0 is
%                    treated as if no seed was given.
%    type            (optional) 'single' or 'double' (default): the class
%                    of the output. Can be given at most once.
%
% Output:
%    result          array of size s1 x s2 x ... sN. If the seed option is
%                    used, repeated calls with the same seed and element
%                    dimensions give the same result
% Example:
%     % generate 2x2 pseudo-random number matrices twice, just like 'rand'
%     % (repeated calls give different outputs)
%     x1=cosmo_rand(2,2);
%     x2=cosmo_rand(2,2);
%     isequal(x1,x2)
%     %|| false
%     %
%     % as above, but specify a seed; repeated calls give the same output
%     x3=cosmo_rand(2,2,'seed',314);
%     x4=cosmo_rand(2,2,'seed',314);
%     isequal(x3,x4)
%     %|| true
%     %
%     % using a different seed gives a different output
%     x5=cosmo_rand(2,2,'seed',315);
%     isequal(x3,x5)
%     %|| false
%
%
% Notes:
%   - this function behaves like the builtin 'rand' function,
%     except that it supports a 'seed' option, which allows for
%     deterministic pseudo-random number generation
%   - when using the 'seed' option, this function gives identical output
%     under both matlab and octave. To achieve this, the PRNG is set to a
%     different state for the two platforms
%   - this function uses the Mersenne twister algorithm by default, even
%     when 'seed' is used (unlike Matlab and Octave).
%   - when using the 'seed' option, the global PRNG state is not changed:
%     under Matlab a separate RandStream is used, and under Octave the
%     previous state of 'rand' is restored before returning.
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    [sizes,seed,class_func]=process_input(varargin{:});

    randomizer=@rand; % default: use the global PRNG
    if seed~=0
        is_matlab=cosmo_wtf('is_matlab');
        rng_state=get_mersenne_state_from_seed(seed, is_matlab);

        if is_matlab
            % use a private stream, so that the global PRNG is not affected
            stream=RandStream('mt19937ar','Seed',rng_state.Seed);
            stream.State=rng_state.State;

            % pass the stream explicitly to the RandStream 'rand' method
            % (rather than relying on '@stream.rand' binding the variable)
            randomizer=@(sz)rand(stream,sz);
        else
            % preserve old PRNG state; it is restored when 'cleaner' is
            % destroyed, i.e. when this function returns
            orig_rng_state=rand('state');
            cleaner=onCleanup(@()rand('state',orig_rng_state));

            % set random number generation state
            rand('state',rng_state);
        end
    end

    result=class_func(randomizer(sizes));


function rng_state=get_mersenne_state_from_seed(seed, is_matlab)
    % compute a Mersenne twister (MT19937) PRNG state from a seed value
    %
    % rng_state=get_mersenne_state_from_seed(seed, is_matlab)
    %
    % Inputs:
    %   seed            numeric seed value; it is converted to uint64 and
    %                   only its lower 32 bits are used
    %   is_matlab       true to return a state for Matlab, false for Octave
    %
    % Output:
    %   rng_state       - if is_matlab: struct with fields .State (625x1
    %                     uint32), .Type ('twister') and .Seed (uint32(0))
    %                   - otherwise: 625x1 uint64 vector, suitable for
    %                     Octave's rand('state',...)
    %
    % The first 624 elements follow the standard MT19937 seeding
    % recurrence; the 625-th element is a counter whose encoding differs
    % between Matlab and Octave.
    %
    % based on pseudo-code from wikipedia:
    % http://en.wikipedia.org/wiki/Mersenne_twister
    %
    % The result for the most recent (seed, is_matlab) pair is cached, so
    % repeated calls with the same input skip the initialization loop.
    % Both inputs are part of the cache key, because the returned value
    % (vector or struct) depends on is_matlab.
    persistent cached_key
    persistent cached_rng_state

    key={seed, is_matlab};
    if isequal(cached_key,key)
        rng_state=cached_rng_state;
        return;
    end

    max_uint32=2^32-1;
    state=uint64(zeros(625,1));
    state(1)=bitand(uint64(seed),max_uint32);

    mersenne_mult=uint64(1812433253);

    % state(j) < 2^32 and mersenne_mult < 2^31, so the product fits in
    % uint64 without saturating; bitand keeps the lower 32 bits
    for j=1:623
        v=mersenne_mult.*bitxor(state(j),bitshift(state(j),-30))+uint64(j);
        state(j+1)=bitand(v,max_uint32);
    end

    state(end)=1;

    if is_matlab
        % reverse counter relative to Octave
        % (this is undocumented in both Matlab and Octave)
        state(end)=uint64(625)-state(end);

        % matlab uses a struct to set the state
        rng_state=struct();
        rng_state.State=uint32(state);
        rng_state.Type='twister';
        rng_state.Seed=uint32(0);
    else
        % octave uses a vector to set the state
        rng_state=state;
    end

    cached_rng_state=rng_state;
    cached_key=key;

function y=identify_func(y)
    % identity function: returns its input unchanged
    %
    % Serves as the output conversion function in process_input when no
    % type or 'double' is requested, so that the output of the random
    % number generator is passed through as is.

function [sizes,seed,class_func]=process_input(varargin)
    % parse the input arguments of cosmo_rand
    %
    % [sizes,seed,class_func]=process_input(varargin)
    %
    % Inputs:
    %   varargin        zero or more numeric size arguments (non-negative
    %                   scalars or vectors), followed by zero or more
    %                   string options: 'single' or 'double' (at most one
    %                   of these in total), or 'seed' followed by a
    %                   non-negative scalar seed value
    %
    % Outputs:
    %   sizes           row vector with all size arguments concatenated,
    %                   or 1 if no size argument was given
    %   seed            the seed value, or 0 if no seed was given
    %   class_func      handle applied to the generated numbers: @single
    %                   for 'single', otherwise @identify_func (no change)
    %
    % The result for the most recent input is cached, to avoid re-parsing
    % when called repeatedly with the same arguments.
    persistent cached_key;
    persistent cached_sizes;
    persistent cached_seed;
    persistent cached_class_func;

    % isequal may consider values of different classes equal (e.g. a char
    % array and its numeric character codes), so the class of each
    % argument is made part of the cache key
    key=[varargin, cellfun(@class,varargin,'UniformOutput',false)];
    if ~isempty(cached_class_func) && isequal(key,cached_key)
        sizes=cached_sizes;
        seed=cached_seed;
        class_func=cached_class_func;
        return;
    end

    n=numel(varargin);

    seed=0;
    sizes_cell=cell(1,n);
    class_func=[];

    % set once the first string argument is seen; numeric arguments
    % are not allowed after that
    has_processed_sizes=false;

    % process each argument
    k=0;
    while k<n
        k=k+1;
        arg=varargin{k};
        if isnumeric(arg)
            if has_processed_sizes
                error('size argument not allowed after string argument');
            end

            ensure_positive_vector(k,arg);
            sizes_cell{k}=arg(:)';

        elseif ischar(arg)
            has_processed_sizes=true;
            switch arg
                case 'single'
                    if ~isempty(class_func)
                        error('type can only be set once');
                    end
                    class_func=@single;

                case 'double'
                    if ~isempty(class_func)
                        error('type can only be set once');
                    end
                    class_func=@identify_func;

                case 'seed'
                    % the seed value is the next argument
                    k=k+1;
                    if k>n
                        error('missing value after key ''%s''', arg);
                    end
                    value=varargin{k};
                    ensure_positive_scalar(k,value);
                    seed=value;

                otherwise
                    error('unsupported key ''%s''', arg);
            end

        else
            error('illegal input at position %d', k);
        end
    end

    sizes=[sizes_cell{:}];

    % no size provided, output is scalar
    if isempty(sizes)
        sizes=1;
    end

    if isempty(class_func)
        class_func=@identify_func;
    end

    cached_key=key;
    cached_sizes=sizes;
    cached_seed=seed;
    cached_class_func=class_func;

function ensure_positive_scalar(k,arg)
    % raise an error unless arg is a single numeric value that is >= 0
    %
    % Inputs:
    %   k           position of arg in the argument list (used in the
    %               error messages)
    %   arg         value to check
    %
    % The vector check runs first, so a negative or non-numeric value is
    % reported as 'not positive' before its size is considered.
    ensure_positive_vector(k,arg);

    if isscalar(arg)
        return;
    end

    error('argument at position %d is not a scalar',k);

function ensure_positive_vector(k,arg)
    % raise an error unless arg is a numeric vector without negative values
    %
    % Inputs:
    %   k           position of arg in the argument list (used in the
    %               error message)
    %   arg         value to check; must be a numeric vector whose
    %               elements are all >= 0 (zero is accepted)
    is_valid=isvector(arg) && isnumeric(arg) && all(arg>=0);
    if ~is_valid
        error('argument at position %d is not positive',k);
    end