cosmo cartprod skl

function p=cosmo_cartprod(xs, convert_to_numeric)
% returns the cartesian product with all combinations of the input
%
% p=cosmo_cartprod(xs[, convert_to_numeric])
%
% Inputs:
%   xs                   Px1 cell array with values for which the product
%                        is to be returned. Each element xs{k} should be
%                        - a cell with Qk values
%                        - a numeric or logical array [xk_1,...,xk_Qk],
%                          which is interpreted as the cell
%                          {xk_1,...,xk_Qk}.
%                        - or a string s, which is interpreted as {s}.
%                        Alternatively xs can be a struct with P fieldnames
%                        where each value is of one of the forms above.
%
%   convert_to_numeric   Optional; if true (default), then when every
%                        value in the output is a numeric scalar, a
%                        numerical matrix is returned; otherwise a cell is
%                        returned.
% Output:
%   p                    QxP cartesian product of xs (where Q=Q1*...*QP)
%                        containing all combinations of values in xs.
%                        Values of xs{1} vary fastest over the rows, those
%                        of xs{P} slowest.
%                        - If xs is a cell, then p is represented by either
%                          a matrix (if all values in the output are
%                          numeric scalars and convert_to_numeric==true) or
%                          a cell (in all other cases).
%                        - If xs is a struct, then p is a Qx1 cell. Each
%                          element in p is a struct with the same
%                          fieldnames as xs.
%                        - If xs is empty (or a struct without fields),
%                          then p is an empty 1x0 cell.
%
% Examples:
%     cosmo_cartprod({{1,2},{'a','b','c'}})'
%     %|| {1,2,1,2,1,2;
%     %|| 'a','a','b','b','c','c'}
%
%     cosmo_cartprod({[1,2],[5,6,7]})'
%     %|| [1,2,1,2,1,2;
%     %||  5,5,6,6,7,7]
%
%     cosmo_cartprod(repmat({1:2},1,4))'
%     %|| [1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2;
%     %||  1 1 2 2 1 1 2 2 1 1 2 2 1 1 2 2;
%     %||  1 1 1 1 2 2 2 2 1 1 1 1 2 2 2 2;
%     %||  1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2]
%
%     s=struct();
%     s.roi={'v1','loc'};
%     s.hemi={'L','R'};
%     s.subj=[1 3 9];
%     s.ana='vis';
%     s.beta=4;
%     p=cosmo_cartprod(s)';
%     cosmo_disp(p);
%     %|| { .roi     .roi     .roi    ... .roi     .roi     .roi
%     %||     'v1'     'loc'    'v1'        'loc'    'v1'     'loc'
%     %||   .hemi    .hemi    .hemi       .hemi    .hemi    .hemi
%     %||     'L'      'L'      'R'         'L'      'R'      'R'
%     %||   .subj    .subj    .subj       .subj    .subj    .subj
%     %||     [ 1 ]    [ 1 ]    [ 1 ]       [ 9 ]    [ 9 ]    [ 9 ]
%     %||   .ana     .ana     .ana        .ana     .ana     .ana
%     %||     'vis'    'vis'    'vis'       'vis'    'vis'    'vis'
%     %||   .beta    .beta    .beta       .beta    .beta    .beta
%     %||     [ 4 ]    [ 4 ]    [ 4 ]       [ 4 ]    [ 4 ]    [ 4 ] }@1x12
%     %
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    if nargin<2, convert_to_numeric=true; end

    as_struct=isstruct(xs);

    if as_struct
        % input is a struct; put the values of each field in a cell and
        % keep the fieldnames so that structs can be rebuilt afterwards
        [xs,fns]=struct2cell(xs);
    elseif ~iscell(xs)
        error('Unsupported input: expected a cell or struct');
    end

    if isempty(xs)
        % no dimensions at all
        p=cell(1,0);
        return
    end

    p=cartprod(xs);

    % if input was a struct, output is a cell with structs
    if as_struct
        p=cell2structs(p, fns);
    elseif convert_to_numeric && ~isempty(p) && ...
                all(cellfun(@(v) isnumeric(v) && isscalar(v), p(:)))
        % every value is a single number; convert to a numeric matrix.
        % Requiring scalars matters: with empty or multi-element numeric
        % values, [p{:}] no longer has exactly one element per cell, so
        % the reshape would either fail or silently misalign the values.
        p=reshape([p{:}],size(p));
    end

function p=cartprod(xs)
% Recursive helper: cartesian product of the value sets in cell array xs.
%
% Input:
%   xs    non-empty cell array. Each element is a cell of values, a
%         numeric or logical array (each element becomes one value), or a
%         char array (treated as a single value).
% Output:
%   p     QxP cell with P=numel(xs), where Q is the product of the number
%         of values in each xs{k}. Values of xs{1} vary fastest down the
%         rows. If any xs{k} has no values, p is an empty 0xP cell.

    ndim=numel(xs);

    % get values in first dimension (the 'head')
    xhead=xs{1};
    if isnumeric(xhead) || islogical(xhead)
        % put numeric arrays in a cell, one value per element
        xhead=num2cell(xhead);
    elseif ischar(xhead)
        xhead={xhead};
    end

    % ensure head is a column vector
    xhead=xhead(:);

    if ndim==1
        p=xhead;
        return
    end

    % use recursion to find cartprod of remaining dimensions (the 'tail')
    ptail=cartprod(xs(2:end));

    nhead=numel(xhead);
    ntail=size(ptail,1);

    if nhead==0 || ntail==0
        % product with an empty set is empty; return a 0xP cell instead of
        % letting the stacking below collapse into [] (a 0x0 double)
        p=cell(0,ndim);
        return
    end

    % output row r combines head value mod(r-1,nhead)+1 with tail row
    % floor((r-1)/nhead)+1, so the head values vary fastest
    head_idx=repmat((1:nhead)',ntail,1);
    tail_idx=reshape(repmat(1:ntail,nhead,1),[],1);

    p=[xhead(head_idx), ptail(tail_idx,:)];



function [c,fns]=struct2cell(xs)
% Gather the value of every field of xs into a 1xP cell c (in fieldname
% order); the fieldnames themselves are returned in fns.
    fns=fieldnames(xs);
    nfields=numel(fns);

    c=cell(1,nfields);
    for i_field=1:nfields
        fieldname=fns{i_field};
        c{i_field}=xs.(fieldname);
    end

function struct_cell=cell2structs(p, fns)
% Turn every row of cell p into a struct: field fns{col} of the struct
% built from row r holds p{r,col}. Returns a column cell of these structs,
% one per row of p.
    nrows=size(p,1);
    struct_cell=cell(nrows,1);

    for row=1:nrows
        row_values=p(row,:);

        s=struct();
        for col=1:numel(fns)
            s.(fns{col})=row_values{col};
        end

        struct_cell{row}=s;
    end