test distatis

function test_suite = test_distatis
    % tests for cosmo_distatis
    %
    % Entry point of this test file: returns the suite of local test
    % functions defined below (those whose names start with 'test').
    %
    % NOTE(review): initTestSuite is run as a script in this workspace;
    % it presumably reads 'test_functions' (when set) and assigns
    % 'test_suite' - confirm against the test framework in use.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function test_statis_
    % Check cosmo_distatis against the worked DISTATIS example from:
    %
    %   Abdi, H. & Valentin, D. in Encyclopedia of Measurement and
    %   Statistics (Salkind, N.) 42-42 (SAGE Publications, 2007).
    %
    % Covers 'crossproduct' and 'distance' output, 'square' and
    % 'triangle' shapes, a cell with numeric vectors as input, 'uniform'
    % weights, and exceptions for invalid input and invalid options.

    % distance matrices (one per subject) from the reference
    d = cell(0);
    % note: element [1,3] is reported as .148, for symmetry use .146
    d{1} = [0    0.1120    0.1460    0.0830    0.1860    0.1100
            0.1120         0    0.1520    0.0980    0.1580    0.1340
            0.1460    0.1520         0    0.2020    0.2850    0.2490
            0.0830    0.0980    0.2020         0    0.1310    0.1100
            0.1860    0.1580    0.2850    0.1310         0    0.1550
            0.1100    0.1340    0.2490    0.1100    0.1550         0];

    d{2} = [0    0.6000    1.9800    0.4200    0.1400    0.5800
            0.6000         0    2.1000    0.7800    0.4200    1.3400
            1.9800    2.1000         0    2.0200    1.7200    2.0600
            0.4200    0.7800    2.0200         0    0.5000    0.8800
            0.1400    0.4200    1.7200    0.5000         0    0.3000
            0.5800    1.3400    2.0600    0.8800    0.3000         0];

    d{3} = [0    0.5400    1.3900    5.7800   10.2800    6.7700
            0.5400         0    1.0600    3.8000    6.8300    4.7100
            1.3900    1.0600         0    8.0100   11.0300    5.7200
            5.7800    3.8000    8.0100         0    2.5800    6.0900
            10.2800    6.8300   11.0300    2.5800         0    3.5300
            6.7700    4.7100    5.7200    6.0900    3.5300         0];
    d{4} = [0    0.0140    0.1590    0.0040    0.0010    0.0020
            0.0140         0    0.0180    0.0530    0.0240    0.0040
            0.1590    0.0180         0    0.2710    0.0670    0.0530
            0.0040    0.0530    0.2710         0    0.0010    0.0080
            0.0010    0.0240    0.0670    0.0010         0    0.0070
            0.0020    0.0040    0.0530    0.0080    0.0070         0];

    % sizes derived from the data instead of hard-coded
    nsubj = numel(d);                       % number of distance matrices
    nitems = size(d{1}, 1);                 % items per matrix
    npairs = nitems * (nitems - 1) / 2;     % distances per matrix

    ds = get_distance_dataset(d);
    % three identical copies along the feature dimension
    ds = cosmo_stack({ds, ds, ds}, 2);

    cosmo_check_dataset(ds);

    % compromise as a cross-product matrix
    opt = struct();
    opt.split_by = 'subject';
    opt.return = 'crossproduct';
    opt.progress = false;
    res = cosmo_distatis(ds, opt);

    % note: S_{[+]}[6,2] is reported as -0.01, should be -.100
    s = [.176 .004 -.058 .014 -.100 -.036
         .004 .178 .022 -.038 -.068 -.100
         -.058 .022 .579 -.243 -.186 -.115
         .014 -.038 -.243 .240 .054 -.027
         -.100 -.068 -.186 .054 .266 .034
         -.036 -.100 -.115 -.027 .034 .243];

    u = cosmo_unflatten(res, 1);

    assertElementsAlmostEqual(repmat(s, [1, 1, 3]), u, 'absolute', .001);
    assertElementsAlmostEqual(res.fa.quality(2), .6551, 'absolute', .001);

    % compromise as a distance matrix, square shape
    opt.return = 'distance';
    opt.shape = 'square';
    res = cosmo_distatis(ds, opt);
    u = cosmo_unflatten(res, 1);
    sq = cosmo_squareform(u(:, :, 1));
    assertElementsAlmostEqual(sq, [0.3452    0.8710    0.3888    0.6419 ...
                                   0.4911    0.7112    0.4919    0.5789 ...
                                   0.6203    1.3049    1.2163    1.0512 ...
                                   0.3996  0.5354 0.4400], 'absolute', .001);

    % triangle shape must give the same distances as a vector
    opt.shape = 'triangle';
    resvec = cosmo_distatis(ds, opt);
    assertElementsAlmostEqual(resvec.samples(:, 1), sq');

    % test numeric input: one column vector of pairwise distances per
    % subject, taken from the first feature of the dataset
    vec_samples = mat2cell(ds.samples(:, 1), repmat(npairs, nsubj, 1), 1);
    resvec2 = cosmo_distatis(vec_samples, opt);
    assertElementsAlmostEqual(resvec2.samples, resvec.samples(:, 1));
    assertEqual(cellfun(@numel, resvec2.a.sdim.values), ...
                cellfun(@numel, resvec.a.sdim.values));
    resvec2.a.sdim.values = resvec.a.sdim.values;

    % apart from sdim labels and floating-point noise, numeric input must
    % give the same result as the first feature of the dataset input
    resvec_single_feature = cosmo_slice(resvec, 1, 2);
    assertElementsAlmostEqual(resvec_single_feature.samples, ...
                              resvec2.samples);
    assertElementsAlmostEqual(resvec_single_feature.fa.quality, ...
                              resvec2.fa.quality);
    resvec2.samples = resvec_single_feature.samples;
    resvec2.fa.quality = resvec_single_feature.fa.quality;
    assertEqual(resvec2, resvec_single_feature);

    % uniform weights
    opt.weights = 'uniform';
    resvec = cosmo_distatis(ds, opt);
    assertElementsAlmostEqual(resvec.samples(:, 1)', ...
                              [0.3156 0.8625 0.3704 0.6043 0.4646 ...
                               0.6572 0.4788 0.5489 0.5783 1.3221 ...
                               1.1576 0.9885 0.3644 0.5068 0.4037], ...
                              'absolute', .001);

    % test exceptions; progress reporting is always switched off
    aet = @(varargin)assertExceptionThrown(@() ...
                                           cosmo_distatis(varargin{:}, 'progress', false), '');

    % no compromise possible
    d2 = d;
    d2{4}(1, 5) = 3;
    d2{4}(5, 1) = 3;
    opt.weights = 'eig';
    ds2 = get_distance_dataset(d2);
    aet(ds2, opt);

    % illegal arguments
    % NOTE(review): opt is reset below without 'split_by'; chunks are set
    % presumably so the default split (by chunks) succeeds and only the
    % option under test is invalid - confirm against cosmo_distatis
    ds.sa.chunks = ds.sa.subject;
    opt = struct();
    opt.shape = 'foo';
    aet(ds, opt);

    opt = struct();
    opt.weights = 'foo';
    aet(ds, opt);

    opt = struct();
    opt.return = 'foo';
    aet(ds, opt);

    opt = struct();
    % cannot deal with empty input
    aet({}, opt);

    % needs dataset or numeric input
    aet({false, true}, opt);

    % cannot take non-matrix input
    aet({zeros([2 2 2])}, opt);

    % illegal field in dataset
    ds_bad = ds;
    ds_bad.foo = 2;
    opt = struct();
    aet(ds_bad, opt);


function ds = get_distance_dataset(d)
    % Build a dataset of pairwise distances, one 'subject' per matrix.
    %
    % Input:
    %   d       cell with square distance matrices. NOTE(review): they
    %           presumably must be symmetric with a zero diagonal for
    %           cosmo_squareform, and of equal size for cosmo_stack to
    %           succeed - confirm.
    % Output:
    %   ds      dataset with a single feature. Each sample is the distance
    %           between one pair of items of one matrix:
    %             .sa.subject   index k of the matrix d{k}
    %             .sa.targets1  row index of the pair (the larger index)
    %             .sa.targets2  column index of the pair (the smaller one)
    %             .a.sdim       item labels 'f1', 'f2', ... for both
    %                           target dimensions
    nsubj = numel(d);
    ds_all = cell(nsubj, 1);

    for k = 1:nsubj
        nd = size(d{k}, 1);

        % distances as a column vector, one sample per pair of items
        sq = cosmo_squareform(d{k});
        samples = sq(:);
        ns = size(samples, 1);

        % item pairs of the strict lower triangle, in column-major order:
        % (2,1), (3,1), ..., (nd,1), (3,2), ...; presumably the same
        % order as the elements of cosmo_squareform's vector output
        [i, j] = find(tril(true(nd), -1));

        % one label per item ('f1', 'f2', ...) as an nd x 1 cell, so that
        % matrices of any size get matching labels
        labels = arrayfun(@(x) sprintf('f%d', x), (1:nd)', ...
                          'UniformOutput', false);

        ds_k = struct();
        ds_k.samples = samples;
        ds_k.sa.subject = k * ones(ns, 1);
        ds_k.sa.targets1 = i;
        ds_k.sa.targets2 = j;
        ds_k.a.sdim.values = {labels, labels};
        ds_k.a.sdim.labels = {'targets1', 'targets2'};

        ds_all{k} = ds_k;
    end

    ds = cosmo_stack(ds_all);