test confusion matrix

function test_suite = test_confusion_matrix
    % tests for cosmo_confusion_matrix
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    initTestSuite;

function classes = test_confusion_matrix_basics()
    nsamples = 30;
    ntargets = 5;
    delta = 10;

    targets = [ceil(ntargets * rand(nsamples, 1)); randperm(ntargets)'] + delta;
    predicted = [ceil(ntargets * rand(nsamples, 1)); randperm(ntargets)'] + delta;

    [mx, classes] = cosmo_confusion_matrix(targets, predicted);

    assertEqual(classes, delta + (1:ntargets)');

    assertEqual(size(mx), [ntargets, ntargets]);

    for k = 1:ntargets
        for j = 1:ntargets
            count = sum(targets == (k + delta) & predicted == (j + delta));
            assertEqual(count, mx(k, j));
        end
    end

    ds = struct();
    ds.samples = predicted;
    ds.sa.targets = targets;

    [mx2, classes2] = cosmo_confusion_matrix(ds);
    assertEqual(mx, mx2);
    assertEqual(classes, classes2);

    predicted3 = predicted(randperm(numel(predicted)));
    mx3 = cosmo_confusion_matrix(targets, predicted3);

    ds.samples = [predicted predicted3(:)];
    [mx_both, classes3] = cosmo_confusion_matrix(ds);
    assertEqual(mx_both, cat(3, mx2, mx3));
    assertEqual(classes3, classes);

function test_confusion_matrix_exceptions
    aet = @(varargin)assertExceptionThrown(@() ...
                                           cosmo_confusion_matrix(varargin{:}), '');
    % size mismatch
    aet([1; 1], 1);

    % missing target
    aet([1; 1], [1; 2]);

    % no dataset
    aet(struct());
    aet({});

    ds = struct();
    ds.samples = 1;
    aet(ds, 1);
    ds.sa.targets = 1;
    % second argument with dataset
    aet(ds, 1);

    % missing argument with numeric
    aet(1);

    % target row vector
    aet([1 1], [1; 1]);
    aet([1; 1], [1 1]);

    % no target vector
    aet(eye(2), [1; 1]);

    % no target vector
    aet(ones([2 2 2]), [1; 1]);
    aet([1; 1], ones([2 2 2]));