test phase stat

function test_suite=test_phase_stat
% tests for cosmo_phase_stat
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        % collect handles to all local test functions in this file;
        % presumably read by the initTestSuite script from this workspace
        test_functions=localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
        % (the error is deliberately ignored here)
    end
    % initTestSuite is invoked as a script, so it runs in this function's
    % workspace and sets the output 'test_suite'
    initTestSuite;

function r=randint()
% return a small random integer: greater than 10 and at most 20
    r=ceil(10+10*rand());

function test_phase_stat_basic_balanced_trials
% compare all statistics against the reference implementation when both
% conditions have the same (random) number of trials
    ntrials=randint();
    helper_test_phase_stat_with_trial_counts(ntrials,ntrials);


function test_phase_stat_basic_unbalanced_trials
% for every statistic, cosmo_phase_stat must accept a dataset with equal
% trial counts in both conditions, and raise an exception when the two
% conditions differ by one trial
    stat_names={'pos','pop','pbi'};
    deltas=[-1,0,1];
    for i_name=1:numel(stat_names)
        for i_delta=1:numel(deltas)
            delta=deltas(i_delta);
            ntrials=randint();
            ds=generate_phase_dataset(ntrials,ntrials+delta);
            opt=struct('output',stat_names{i_name});

            func=@()cosmo_phase_stat(ds,opt);
            if delta~=0
                % unbalanced: an exception is required
                assertExceptionThrown(func,'');
            else
                % balanced: must run without error
                func();
            end
        end
    end

function helper_test_phase_stat_with_trial_counts(ntrials1,ntrials2)
% generate one dataset with ntrials1 trials in condition 1 and ntrials2
% trials in condition 2, then check each supported statistic on it
    ds=generate_phase_dataset(ntrials1,ntrials2);

    stat_names={'pos','pop','pbi'};
    for i_stat=1:numel(stat_names)
        helper_test_phase_stat_with_name(ds,stat_names{i_stat});
    end


function helper_test_phase_stat_with_name(ds,stat_name)
% check cosmo_phase_stat output for statistic stat_name against a
% reference computed from the ITC of each condition and of all trials;
% also check that .a and .fa are kept and .sa is emptied
    result=cosmo_phase_stat(ds,struct('output',stat_name));

    x=ds.samples;
    in_first=ds.sa.targets==1;
    in_second=ds.sa.targets==2;

    % the reference only makes sense for balanced conditions
    assert(sum(in_first)==sum(in_second));

    expected_samples=compute_phase_stat(stat_name,...
                            compute_itc(x(in_first,:)),...
                            compute_itc(x(in_second,:)),...
                            compute_itc(x));

    assertElementsAlmostEqual(result.samples,expected_samples);
    assertEqual(ds.a,result.a);
    assertEqual(ds.fa,result.fa);
    assertEqual(struct(),result.sa);

function test_phase_stat_with_signal()
% every statistic should increase with the phase difference between the
% two conditions
    stat_names={'pos','pop','pbi'};
    for i_stat=1:numel(stat_names)
        helper_test_phase_stat_with_signal_with_name(stat_names{i_stat});
    end


function helper_test_phase_stat_with_signal_with_name(name)
% Verify that statistic 'name' increases with the phase difference
% between the two conditions.
%
% A single-feature dataset with 2 targets is used. For each signal level
% in 0:10, uniform random phases are drawn for both conditions, and the
% phases of condition 1 are shifted by an amount proportional to the
% signal level. The statistic computed by cosmo_phase_stat must correlate
% positively (r>.5) with the signal level.
    ds=cosmo_synthetic_dataset('ntargets',2,'nchunks',50);
    nsamples=size(ds.samples,1);
    % all samples independent
    ds.sa.chunks(:)=1:nsamples;
    % keep only the first feature
    ds=cosmo_slice(ds,1,2);
    nfeatures=size(ds.samples,2);

    % scale factor that keeps phase angle differences small
    sd=pi/100;
    signals=0:10;
    nsignals=numel(signals);

    result=zeros(nsignals,1);
    for k=1:nsignals
        for target=[1,2]
            msk=ds.sa.targets==target;

            % named 'phase_units' and 'phi' (rather than 'rng' and
            % 'angle') to avoid shadowing the builtin functions
            phase_units=rand(sum(msk),nfeatures);

            if target==1
                % add increasing difference between two classes
                phase_units=phase_units+signals(k);
            end

            phi=2*pi*phase_units*sd;
            ds.samples(msk,:)=exp(1i*phi);
        end

        s=cosmo_phase_stat(ds,'output',name);
        result(k)=s.samples;
    end

    assert(cosmo_corr(result,signals')>.5);


function s=compute_phase_stat(stat_name,itc1,itc2,itc_all)
% reference implementation of the phase statistics, element-wise:
%   'pbi': (itc1-itc_all).*(itc2-itc_all)
%   'pop': itc1.*itc2 - itc_all.^2
%   'pos': itc1+itc2 - 2*itc_all
% any other stat_name fails with an assertion error
    if strcmp(stat_name,'pbi')
        s=(itc1-itc_all).*(itc2-itc_all);
    elseif strcmp(stat_name,'pop')
        s=(itc1.*itc2)-itc_all.^2;
    elseif strcmp(stat_name,'pos')
        s=(itc1+itc2)-2*itc_all;
    else
        assert(false);
    end


function itc=compute_itc(samples)
% Compute inter-trial coherence (ITC) for each column of samples.
%
% Input:
%   samples     PxQ complex matrix (rows: trials, columns: features)
% Output:
%   itc         1xQ vector; for each column, the magnitude of the mean of
%               the samples after normalizing each to unit length
%
% Notes:
%   - the sum is taken explicitly along the first dimension so that a
%     single trial (P==1) still yields one value per feature, rather
%     than summing across features
%   - elements with zero magnitude produce NaN (0/0)
    assert(~isreal(samples))
    unit_samples=samples./abs(samples);

    itc=abs(sum(unit_samples,1))/size(unit_samples,1);



function idx=select_randomly(targets,value,count)
% return, as a column vector in random order, 'count' distinct indices of
% the elements of targets that are equal to value. Indexing fails if
% fewer than 'count' such elements exist.
    candidates=find(targets(:)==value);
    ncandidates=numel(candidates);

    % random permutation obtained by sorting uniform random values
    [unused,order]=sort(rand(ncandidates,1));
    chosen=order(1:count);

    idx=candidates(chosen);


function ds=generate_phase_dataset(varargin)
% build a dataset with random complex samples
%
% ds=generate_phase_dataset(n1,n2,...) stacks one synthetic dataset per
% input argument, where the k-th argument is used as the 'nchunks' value
% of the k-th part and all samples of that part get target k. Chunks are
% then renumbered 1..N (all samples independent), and samples are replaced
% by values whose real and imaginary parts are drawn from randn.
    nparts=numel(varargin);
    parts=cell(nparts,1);
    for i_part=1:nparts
        part=cosmo_synthetic_dataset('seed',0,...
                                    'nchunks',varargin{i_part},...
                                    'ntargets',1);
        part.sa.targets(:)=i_part;
        parts{i_part}=part;
    end

    ds=cosmo_stack(parts);
    nsamples=numel(ds.sa.chunks);
    ds.sa.chunks(:)=1:nsamples;

    samples_size=size(ds.samples);
    ds.samples=randn(samples_size)+1i*randn(samples_size);



function test_phase_stat_exceptions
% cosmo_phase_stat must raise an exception for invalid datasets and
% invalid options
    extra_args={'output','pbi'};
    aet=@(varargin)assertExceptionThrown(...
                    @()cosmo_phase_stat(varargin{:}),'');
    % like aet, but with the (valid) 'output' option appended. The
    % arguments must be expanded with varargin{:}: passing the cell
    % varargin itself would make cosmo_phase_stat receive a cell array as
    % its dataset, so every case would raise for the wrong reason.
    aet_arg=@(varargin)aet(varargin{:},extra_args{:});

    ds=cosmo_synthetic_dataset('ntargets',2,'nchunks',6);
    nsamples=size(ds.samples,1);
    sz=size(ds.samples);
    ds.samples=randn(sz)+1i*randn(sz);
    ds.sa.chunks(:)=1:nsamples;
    cosmo_phase_stat(ds,extra_args{:}); % ok

    % with targets being 2 or 3 is also ok
    ds.sa.targets=ds.sa.targets+1;
    cosmo_phase_stat(ds,extra_args{:}); % ok

    % input not complex
    bad_ds=ds;
    bad_ds.samples=randn(sz);
    aet_arg(bad_ds);

    % chunks not all unique
    bad_ds=ds;
    bad_ds.sa.chunks(1)=bad_ds.sa.chunks(2);
    aet_arg(bad_ds);

    % imbalance is not ok (7 samples with target 1, 5 with target 2)
    bad_ds=ds;
    bad_ds.sa.targets(:)=[repmat([1 2],1,5),[1 1]];
    aet_arg(bad_ds);

    % bad values for samples_are_unit_length
    bad_samples_are_unit_length_cell={[],'',1,[true false]};
    for k=1:numel(bad_samples_are_unit_length_cell)
        arg={'samples_are_unit_length',...
                    bad_samples_are_unit_length_cell{k}};
        aet_arg(ds,arg{:});
    end

    % with samples_are_unit_length=true, raise exception if some values
    % are not unit length
    aet_arg(ds,'samples_are_unit_length',true);

    % number of classes must be exactly 2
    for bad_class_count=[1,3,4]
        bad_ds=ds;
        bad_ds.sa.targets(:)=mod(1:nsamples,bad_class_count)+1;
        idx=cosmo_index_unique(bad_ds.sa.targets);
        counts=cellfun(@numel,idx);
        assert(all(counts==counts(1))); % balanced counts
        aet_arg(bad_ds);                % yet an error is raised
    end

    % no samples
    bad_ds=cosmo_slice(ds,[],1);
    aet_arg(bad_ds);

    % single sample
    bad_ds=cosmo_slice(ds,1,1);
    aet_arg(bad_ds);


    % balancer function must be function handle
    % NOTE(review): this uses the single-sample bad_ds from above, which
    % already raises by itself; confirm whether cosmo_phase_stat supports
    % 'balancer_func' and, if so, use the valid 'ds' here instead
    aet_arg(bad_ds,'balancer_func',struct);

    % raise exception when called without the 'output' argument, or wrong
    % output
    aet(ds);
    aet(ds,'output','foo');
    aet(ds,'output',1);