test phase itc

function test_suite = test_phase_itc
    % tests for test_phase_itc
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function r = randint()
    r = ceil(2 + rand() * 10);

function test_phase_itc_basics
    nclasses = randint();
    classes = 1:2:(2 * nclasses);

    nrepeats = randint();
    nfeatures = randint();

    ds = generate_random_dataset(classes, nrepeats, nfeatures);

    % compute expected ITC
    itc_ds = cosmo_phase_itc(ds);
    expected_samples = zeros(nclasses + 1, nfeatures);
    for k = 1:nclasses
        msk = ds.sa.targets == classes(k);
        expected_samples(k, :) = quick_itc(ds.samples(msk, :));
    end
    expected_samples(nclasses + 1, :) = quick_itc(ds.samples);

    % construct expected dataset
    expected_itc_ds = struct();
    expected_itc_ds.samples = expected_samples;
    expected_itc_ds.sa.targets = [classes, NaN]';
    expected_itc_ds.a = ds.a;
    expected_itc_ds.fa = ds.fa;

    assert_datasets_almost_equal(itc_ds, expected_itc_ds);

function test_phase_itc_unit_length()
    % test with 'samples_are_unit_length' option
    ds = generate_random_dataset(1:10, randint(), randint());
    ds_unit = ds;
    ds_unit.samples = ds_unit.samples ./ abs(ds_unit.samples);

    itc_ds = cosmo_phase_itc(ds);
    itc_unit_ds = cosmo_phase_itc(ds_unit, 'samples_are_unit_length', true);

    assert_datasets_almost_equal(itc_ds, itc_unit_ds);

function test_phase_itc_sdim_field
    ds = cosmo_synthetic_dataset('ntargets', 3, 'nchunks', 3);
    ds.samples = ds.samples + 1i * randn(size(ds.samples));
    ds.sa.chunks(:) = 1:9;

    % add sample dimension
    ds = cosmo_dim_insert(ds, 1, 1, {'foo'}, {[1:9]}, {[1:9]'});

    itc_ds = cosmo_phase_itc(ds);
    assert(~isfield(itc_ds.a, 'sdim'));
    assert(~isfield(itc_ds.sa, 'foo'));

function assert_datasets_almost_equal(p, q)
    assertElementsAlmostEqual(p.samples, q.samples);

    p = rmfield(p, 'samples');
    q = rmfield(q, 'samples');

    assertEqual(p, q);

function ds = generate_random_dataset(classes, nrepeats, nfeatures)
    nclasses = numel(classes);
    nsamples = nclasses * nrepeats;
    sz = [nsamples, nfeatures];
    ds = struct();
    ds.samples = randn(sz) + 1i * randn(sz);
    ds.sa.targets = repmat(classes, 1, nrepeats)';
    ds.sa.chunks = (1:nsamples)';
    ds.a = 'foo';
    ds.fa.bar = 1:nfeatures;

    % permute randomly
    ds = cosmo_slice(ds, cosmo_randperm(nsamples));

function itc = quick_itc(samples)
    s = samples ./ abs(samples);
    itc = abs(mean(s, 1));

function test_phase_itc_exceptions()
    aet = @(varargin)assertExceptionThrown( ...
                                           @()cosmo_phase_itc(varargin{:}), '');

    ds = cosmo_synthetic_dataset('ntargets', 2, 'nchunks', 6);
    nsamples = size(ds.samples, 1);
    sz = size(ds.samples);
    ds.samples = randn(sz) + 1i * randn(sz);
    ds.sa.chunks(:) = 1:nsamples;
    cosmo_phase_itc(ds); % ok

    % input not imaginary
    bad_ds = ds;
    bad_ds.samples = randn(sz);
    aet(bad_ds);

    % chunks not all unique
    bad_ds = ds;
    bad_ds.sa.chunks(1) = bad_ds.sa.chunks(2);
    aet(bad_ds);

    % imbalance
    bad_ds = ds;
    bad_ds.sa.targets(:) = [repmat([1 2], 1, 5), [1 1]];
    aet(bad_ds);

    % bad values for samples_are_unit_length
    bad_samples_are_unit_length_cell = {[], '', 1, [true false]};
    for k = 1:numel(bad_samples_are_unit_length_cell)
        arg = {'samples_are_unit_length', ...
               bad_samples_are_unit_length_cell{k}};
        aet(ds, arg{:});
    end

    % with samples_are_unit_length=true, raise exception if some values
    % are not unit length
    aet(ds, 'samples_are_unit_length', true);