cosmo meeg baseline correct skl

function bl_ds=cosmo_meeg_baseline_correct(ds, reference, method)
% correct baseline of MEEG dataset
%
% bl_ds=cosmo_meeg_baseline_correct(ds, reference, method)
%
% Inputs:
%     ds            MEEG dataset struct with 'time' feature dimension
%     reference     Either:
%                   - interval [start, stop], with start and stop in
%                     seconds (in this case, parameters are estimated in
%                     this time interval)
%                   - MEEG dataset struct with 'time' feature dimension,
%                     with features  (in this case, parameters are
%                     estimated using this dataset)
%     method        One of:
%                   - 'absolute'   : bl=samples-mu
%                   - 'relative'   : bl=samples/mu
%                   - 'relchange'  : bl=(samples-mu)/mu
%                   - 'db'         : bl=10*log10(samples/mu)
%
% Output:
%     bl_ds         MEEG dataset struct with the same shape as ds,
%                   where samples are baseline corrected using reference.
%                   This is done separately for each combination of values
%                   in feature dimensions different from 'time', e.g.
%                   for 'chan' for a timelock dataset and for 'chan' and
%                   'freq' combinations for a timefreq dataset
%
%
% Examples:
%     % illustrate 'relative' baseline correction
%     ds=cosmo_synthetic_dataset('type','timelock','size','small');
%     ds=cosmo_slice(ds,1:2); % take first two samples
%     ds_rel=cosmo_meeg_baseline_correct(ds,[-.3,-.18],'relative');
%     cosmo_disp(ds_rel.samples);
%     %|| [ 1     0.572         1      -1.3         1      1.56
%     %||   1     -1.45         1      1.89         1    -0.171 ]
%
%     % illustrate 'absolute' baseline correction
%     ds=cosmo_synthetic_dataset('type','timelock','size','small');
%     ds=cosmo_slice(ds,1:2); % take first two samples
%     ds_abs=cosmo_meeg_baseline_correct(ds,[-.3,-.18],'absolute');
%     cosmo_disp(ds_abs.samples);
%     %|| [ 0    -0.869         0      2.05         0    -0.465
%     %||   0     -1.43         0      1.65         0     -1.36 ]
%
%     % illustrate use of another dataset as reference
%     ds=cosmo_synthetic_dataset('type','timelock','size','small');
%     ds=cosmo_slice(ds,1:2); % take first two samples
%     ref=cosmo_synthetic_dataset('type','timelock','size','small');
%     ref=cosmo_slice(ref,1:2);
%     ds_ref_relch=cosmo_meeg_baseline_correct(ds,ref,'relchange');
%     cosmo_disp(ds_ref_relch.samples);
%     %|| [ 0.272    -0.272     -7.72      7.72     -0.22      0.22
%     %||   -5.41      5.41    -0.309     0.309      1.41     -1.41 ]
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

if isstruct(reference)
    f=@baseline_correct_ds;
elseif isnumeric(reference)
    f=@baseline_correct_interval;
else
    error('illegal reference: expected dataset struct or vector');
end

baseline_label='time';
check_dataset(ds, baseline_label);
bl_ds=f(ds, reference, baseline_label, method);

function check_dataset(ds, baseline_label)
    cosmo_check_dataset(ds);

    [dim, index]=cosmo_dim_find(ds,baseline_label,true);
    if dim~=2
        error(['''%s'' must be feature dimension, found as '...
                'sample dimension'], baseline_label);
    end

function bl=baseline_correct(samples, mu, method)
    assert(size(samples,1)==size(mu,1));
    assert(isvector(mu));
    assert(numel(size(samples))==2);

    switch method
        case 'absolute'
            bl=bsxfun(@minus,samples,mu);
        case 'relchange'
            % use absolute and relative, i.e. (samples-mu)/mu
            bl=baseline_correct(baseline_correct(samples,mu,'absolute'),...
                        mu,'relative');
        case 'relative'
            bl=bsxfun(@rdivide,samples,mu);
        case 'vssum'
            bl=bsxfun(@rdivide,baseline_correct(samples,mu,'absolute'),...
                        bsxfun(@plus,samples,mu));
        case 'db'
            bl=10*log10(baseline_correct(samples,mu,'relative'));
        otherwise
            error('illegal baseline correction method ''%s''',method);
    end


function bl_ds=baseline_correct_interval(ds,interval,baseline_label,method)
    if numel(interval)~=2
        error('interval must have two values');
    end

    matcher=@(x) interval(1) <= x & x <= interval(2);
    msk=cosmo_dim_match(ds,baseline_label,matcher);

    reference=cosmo_slice(ds,msk,2);
    bl_ds=baseline_correct_ds(ds,reference,baseline_label,method);


function bl_ds=baseline_correct_ds(ds,reference,baseline_label,method)
    check_dataset(reference,baseline_label);

    % ensure compatible at sample dimension
    [ds1,ds_sa]=first_feature(ds);
    [ref1,ref_sa]=first_feature(reference);
    if ~isequal(ref_sa, ds_sa)
        error('.sa mismatch for input and reference dataset');
    end

    % split both
    ds_split=split_by_other(ds, baseline_label);
    reference_split=split_by_other(reference, baseline_label);

    n=numel(ds_split);
    nref=numel(reference_split);
    if n~=nref
        error(['Input dataset has different number of dimension '...
                    'combinations (%d) than reference dataset (%d)'],...
                    n,nref);
    end

    for k=1:n
        part_ds=ds_split{k};
        part_reference=reference_split{k};

        fa_ds=cosmo_slice(part_ds.fa,1,2,'struct');
        fa_reference=cosmo_slice(part_reference.fa,1,2,'struct');

        if ~isequal(rmfield(fa_ds,baseline_label),...
                        rmfield(fa_reference,baseline_label))
            error(['Feature mismatch between '...
                        'input dataset and reference dataset']);
        end

        mu=mean(part_reference.samples,2);

        ds_split{k}.samples=baseline_correct(part_ds.samples,mu,method);
    end

    bl_ds=cosmo_stack(ds_split,2);

function ds_split=split_by_other(ds, baseline_label)
    other_dims=setdiff(ds.a.fdim.labels,{baseline_label});
    ds_split=cosmo_split(ds,other_dims,2);


function [ds1, sa]=first_feature(ds, remove_fields)
    ds1=cosmo_slice(ds,1,2);

    sa=struct();
    if isfield(ds,'sa')
        sa=ds.sa;
    end