cosmo phase stat skl

function stat_ds=cosmo_phase_stat(ds,varargin)
% Compute a phase statistic contrasting two conditions: phase bifurcation
% index, phase opposition sum, or phase opposition product
%
% stat_ds=cosmo_phase_stat(ds,...)
%
% Inputs:
%   ds                      dataset struct with fields:
%       .samples            PxQ complex matrix for P samples (trials,
%                           observations) and Q features (e.g. combinations
%                           of time points, frequencies and channels)
%       .sa.targets         Px1 array with trial conditions.
%                           Exactly two conditions are supported, so
%                           .sa.targets must have exactly two unique
%                           values. The number of samples must be
%                           balanced, i.e. each of the two unique values
%                           in .sa.targets must occur equally often.
%       .sa.chunks          Px1 array indicating which samples can be
%                           considered to be independent. All samples
%                           must be independent, therefore all values in
%                           .sa.chunks must differ from each other
%       .fa                 } optional feature attributes
%       .a                  } optional dataset attributes
%  'output',p               (required) Statistic to return, one of:
%                           - 'pbi': phase bifurcation index
%                           - 'pos': phase opposition sum
%                           - 'pop': phase opposition product
%  'samples_are_unit_length',u  (optional)
%                           If u==true, then all elements in ds.samples
%                           are assumed to be already of unit length. If
%                           this is indeed true, this can speed up the
%                           computation of the output.
%  'check_dataset',c        (optional, default=true)
%                           if c==false, the ds input is not checked for
%                           consistency.
%
% Output:
%   stat_ds                 struct with fields
%       .samples            1xQ array with the 'pbi', 'pos', or 'pop'
%                           statistic for each feature
%       .sa                 empty struct
%       .a                  } if present in the input, then the output
%       .fa                 } contains these fields as well
%
% Notes:
%   - if a dataset is not balanced for number of trials, consider using
%     cosmo_balance_dataset to balance it.
%   - all options are forwarded unchanged to cosmo_phase_itc.
%
% See also: cosmo_balance_dataset, cosmo_phase_itc
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #


    % merge the user-supplied options on top of the defaults
    defaults=struct('check_dataset',true);
    opt=cosmo_structjoin(defaults,varargin{:});

    check_inputs(ds,opt);

    % inter-trial coherence per condition and for all trials pooled;
    % cosmo_phase_itc throws an error if the samples are not balanced
    itc_ds=cosmo_phase_itc(ds,opt);

    % two conditions plus the pooled estimate gives exactly three rows
    n_itc_rows=size(itc_ds.samples,1);
    if n_itc_rows~=3
        error(['Input must have exactly two unique values '...
                        'for .sa.targets']);
    end

    % the pooled (all trials) row is the last one, marked by a NaN target
    pooled_row_mask=isnan(itc_ds.sa.targets);
    assert(isequal([false;false;true],pooled_row_mask));

    % combine the three ITC rows into the requested statistic
    itc_rows=itc_ds.samples;
    stat=compute_phase_stat(opt.output,...
                                itc_rows(1,:),...
                                itc_rows(2,:),...
                                itc_rows(3,:));

    % reuse the first ITC row as a template so .fa and .a are carried
    % over; sample attributes no longer apply to the single output row
    stat_ds=cosmo_slice(itc_ds,1,1,false);
    stat_ds.samples=stat;
    stat_ds.sa=struct();



function s=compute_phase_stat(name, itc1, itc2, itc_all)
% Combine per-condition and pooled ITC values into a phase statistic
%
% s=compute_phase_stat(name, itc1, itc2, itc_all)
%
% Inputs:
%   name        statistic to compute: 'pbi', 'pop' or 'pos'
%   itc1        ITC values for the first condition
%   itc2        ITC values for the second condition
%   itc_all     ITC values for all trials together
%
% Output:
%   s           element-wise statistic:
%               'pbi': (itc1-itc_all).*(itc2-itc_all)
%               'pop': itc1.*itc2 - itc_all.^2
%               'pos': itc1+itc2 - 2*itc_all
%
% Any other value for name triggers an assertion failure.
    switch name
        case 'pos'
            % summed condition ITC relative to twice the pooled ITC
            condition_sum=itc1+itc2;
            s=condition_sum-2*itc_all;

        case 'pop'
            % product of condition ITCs relative to squared pooled ITC
            condition_product=itc1.*itc2;
            s=condition_product-itc_all.^2;

        case 'pbi'
            % product of each condition's deviation from the pooled ITC
            delta1=itc1-itc_all;
            delta2=itc2-itc_all;
            s=delta1.*delta2;

        otherwise
            assert(false,'this should not happen');
    end


function check_inputs(ds,opt)
% Validate the dataset and options for the phase statistic computation
%
% check_inputs(ds,opt)
%
% Inputs:
%   ds      candidate dataset; must be a struct with fields .samples
%           and .sa.targets
%   opt     options struct with fields:
%           .check_dataset  if true, ds is first passed to
%                           cosmo_check_dataset
%           .output         required; must be the string 'pbi', 'pos'
%                           or 'pop'
%
% An error is raised as soon as one of the requirements is not met.
    if opt.check_dataset
        cosmo_check_dataset(ds);
    end

    % minimal structural requirements; these are verified even when
    % opt.check_dataset is false
    if ~(isstruct(ds) ...
            && isfield(ds,'samples') ...
            && isfield(ds,'sa') ...
            && isfield(ds.sa,'targets'))
        error(['first input must be struct with fields .samples and '...
                            '.sa.targets']);
    end


    if ~isfield(opt,'output')
        error('option ''output'' is required');
    end

    allowed_values={'pbi','pos','pop'};

    % Test the type before the value: a non-string 'output' (number,
    % cell, struct, ...) is rejected here with the message below instead
    % of being wrapped in a cell and handed to cosmo_match. The &&
    % short-circuits, so cosmo_match only ever sees a string.
    is_allowed=ischar(opt.output) ...
                    && cosmo_match({opt.output},allowed_values);
    if ~is_allowed
        error('option ''output'' must be one of: ''%s''',...
                cosmo_strjoin(allowed_values,''', '''));
    end