test pca

function test_suite = test_pca
    % tests for cosmo_pca
    %
    % Entry point of this test file: it tries to collect handles to the
    % local functions and then runs initTestSuite.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % NOTE(review): initTestSuite is defined outside this file; presumably
    % it uses test_functions and assigns test_suite - confirm
    initTestSuite;

function [pca_samples, coef, mu, expl] = helper_cosmo_pca_wrapper(samples, ...
                                                                  keep_count)
    % Runs cosmo_pca and returns the fields of its second output as
    % separate output arguments.
    %
    % A keep_count of NaN means that cosmo_pca is called with the samples
    % only; otherwise keep_count is passed on as the second argument.
    if isnan(keep_count)
        [pca_samples, pca_params] = cosmo_pca(samples);
    else
        [pca_samples, pca_params] = cosmo_pca(samples, keep_count);
    end

    coef = pca_params.coef;
    mu = pca_params.mu;
    expl = pca_params.explained;

function [pca_samples, coef, mu, expl] = helper_matlab_pca_wrapper(samples, ...
                                                                   keep_count)
    % Reference PCA through the external 'pca' function (Matlab statistics
    % toolbox, or Octave's pca), with outputs in the same order as
    % helper_cosmo_pca_wrapper.
    %
    % A keep_count of NaN means that pca is called without the
    % 'NumComponents' option.
    cosmo_check_external('!pca', true);

    if isnan(keep_count)
        pca_args = {};
    else
        n_keep = keep_count;

        if cosmo_wtf('is_octave') && n_keep <= size(samples, 2)
            % deal with recent (as of 2025) change in behaviour of Octave's
            % pca function: limit the number of components to the number
            % of rows for a row vector, and to one less otherwise
            n_rows = size(samples, 1);
            if n_rows == 1
                max_keep = n_rows;
            else
                max_keep = n_rows - 1;
            end

            n_keep = min(n_keep, max_keep);
        end

        pca_args = {'NumComponents', n_keep};
    end

    % third and fourth output of pca are not used
    [coef, pca_samples, unused_third, unused_fourth, expl_col, mu] = ...
                                            pca(samples, pca_args{:});
    expl = expl_col';

function test_pca_more_samples_than_features
    % correspondence test with 11 to 20 features and ten more samples
    % than features
    nfeatures = ceil(rand() * 10) + 10;
    nsamples = nfeatures + 10;

    helper_test_pca_correspondence(nsamples, nfeatures);

function test_pca_more_features_than_samples
    % correspondence test with 11 to 20 samples and ten more features
    % than samples
    nsamples = ceil(rand() * 10) + 10;
    nfeatures = nsamples + 10;

    helper_test_pca_correspondence(nsamples, nfeatures);

function test_pca_col_vector
    % correspondence test for a single feature and 11 to 20 samples
    n_rows = 10 + ceil(rand() * 10);

    helper_test_pca_correspondence(n_rows, 1);

function test_pca_row_vector
    % correspondence test for a single sample and 11 to 20 features
    n_cols = 10 + ceil(rand() * 10);

    helper_test_pca_correspondence(1, n_cols);

function test_pca_near_square_samples
    % correspondence test with the number of samples one less than,
    % equal to, and one more than the number of features
    n_cols = 10 + ceil(rand() * 10);

    for row_offset = -1:1
        helper_test_pca_correspondence(n_cols + row_offset, n_cols);
    end

function test_pca_too_many_components
    % on square random data, asking for more components than there are
    % features must raise an exception; asking for at most that many
    % must not
    n = ceil(rand() * 10) + 10;

    for offset = -1:1
        n_keep = n + offset;
        run_pca = @()cosmo_pca(rand(n, n), n_keep);

        if n_keep <= n
            run_pca(); % should be ok
        else
            assertExceptionThrown(run_pca, '');
        end
    end

function test_pca_regression
    % regression test: cosmo_pca output for a fixed 6x6 input is compared
    % with stored samples, coefficients, mean and explained variance
    xs = [2.032   -0.8918  -0.8258    1.163    1.157   -1.291; ...
          0.5838   1.844    1.166    -0.8484   3.493   -0.1991; ...
          -1.444   -0.2617  -1.921     3.085   -1.372    1.727; ...
          -0.5177   2.339    0.4412    1.856    0.4794   0.08323; ...
          1.191   -0.204   -0.2088    1.755   -0.9548   0.5012; ...
          -1.326    2.724    0.1476    0.5024   3.407   -0.4803];

    s = [-0.5008    2.8648   -0.7589   -0.5301    0.0144; ...
         3.5030    0.5915    0.2537    0.8888    0.0226; ...
         -3.8914   -1.6525   -0.7549    0.4558    0.0157; ...
         0.1140   -1.3350    1.0615   -0.7253    0.0293; ...
         -2.1851    1.0744    0.9553    0.2445   -0.0444; ...
         2.9603   -1.5432   -0.7566   -0.3338   -0.0376];

    coef = [0.0014    0.7569    0.3369   -0.0540    0.3234; ...
            0.4052   -0.5164    0.4105   -0.3355   -0.0290; ...
            0.3133    0.0275    0.6618    0.0901    0.1745; ...
            -0.4294   -0.2190   -0.0420   -0.5306    0.6797; ...
            0.7090    0.0436   -0.5070    0.0744    0.4815; ...
            -0.2251   -0.3313    0.1456    0.7677    0.4127];

    mu = [0.0864    0.9248   -0.2001    1.2522    1.0349    0.0568];

    explained = [64.7794   26.0994    6.0071    3.1059    0.0082];

    % values that do not change across iterations
    max_ncomp = 5;
    max_nkeep = 6;
    tolerance_arg = {'absolute', 5e-3};
    expected_fieldnames = {'coef', 'explained', 'mu'};
    sorted_expected_fieldnames = sort(expected_fieldnames(:));

    % NaN stands for calling cosmo_pca without a second argument
    nkeep_values = [NaN, 1:7];

    for k = 1:numel(nkeep_values)
        nkeep = nkeep_values(k);

        if isnan(nkeep)
            args = {};
            ncomp = max_ncomp;
        else
            args = {nkeep};
            ncomp = min(nkeep, max_ncomp);
        end

        if nkeep > max_nkeep
            assertExceptionThrown(@()cosmo_pca(xs, args{:}), '');
            continue
        end

        [xs_pca, param] = cosmo_pca(xs, args{:});

        assertElementsAlmostEqual(xs_pca, s(:, 1:ncomp), ...
                                  tolerance_arg{:});

        assertEqual(sort(fieldnames(param)), ...
                    sorted_expected_fieldnames);
        assertElementsAlmostEqual(param.coef, coef(:, 1:ncomp), ...
                                  tolerance_arg{:});
        assertElementsAlmostEqual(param.mu, mu, ...
                                  tolerance_arg{:});
        assertElementsAlmostEqual(param.explained, explained, ...
                                  tolerance_arg{:});
    end

function test_pca_basic_properties
    % checks basic properties of cosmo_pca output on random data that
    % has more samples than features
    n_cols = ceil(rand() * 10 + 10);
    n_rows = ceil(rand() * 10 + 10) + n_cols;

    data = randn(n_rows, n_cols);
    [scores, pca_params] = cosmo_pca(data);

    cross_prod = scores' * scores;
    cross_prod_diag = diag(cross_prod);

    % explained variance is on diagonal
    assertElementsAlmostEqual(100 * cross_prod_diag / trace(cross_prod), ...
                              pca_params.explained');

    % components are orthogonal
    off_diag = cross_prod - diag(cross_prod_diag);
    assertElementsAlmostEqual(off_diag, zeros(n_cols));

    % average is computed correctly
    assertElementsAlmostEqual(mean(data, 1), pca_params.mu);

    % data can be reconstructed from scores, coefficients and average
    reconstructed = bsxfun(@plus, pca_params.mu, scores * pca_params.coef');
    assertElementsAlmostEqual(data, reconstructed);

function test_pca_exceptions()
    % cosmo_pca must throw an exception for a struct, a cell, and a
    % three-dimensional array as input
    bad_inputs = {struct, {1}, randn([2 2 2])};

    for k = 1:numel(bad_inputs)
        bad_input = bad_inputs{k};
        assertExceptionThrown(@()cosmo_pca(bad_input), '');
    end

function helper_test_pca_correspondence(nsamples, nfeatures)
    % Runs the per-nkeep correspondence check for a range of nkeep
    % values; the test is skipped if the external 'pca' is not available.
    if cosmo_skip_test_if_no_external('!pca')
        return
    end

    dims = [nsamples, nfeatures];

    % NaN is passed through to the wrappers, which treat it as
    % "no number of components given"
    nkeep_values = [NaN, -1, 0, 1, ...
                    ceil(dims / 2), ...
                    dims - 1, ...
                    dims, ...
                    dims + 1];

    for k = 1:numel(nkeep_values)
        helper_test_pca_correspondence_nkeep(nsamples, nfeatures, ...
                                             nkeep_values(k));
    end

function helper_test_pca_correspondence_nkeep(nsamples, nfeatures, nkeep)
    % Compares helper_cosmo_pca_wrapper with helper_matlab_pca_wrapper on
    % random data of size nsamples x nfeatures for a single nkeep value.
    data = rand(nsamples, nfeatures);

    reference_threw = false;
    try
        [ref_samples, ref_coef, ref_mu, ref_expl] = ...
                            helper_matlab_pca_wrapper(data, nkeep);
    catch
        reference_threw = true;
    end

    if reference_threw
        % the reference implementation raised an exception, so
        % cosmo pca should also throw error
        assertExceptionThrown(@()helper_cosmo_pca_wrapper(data, nkeep), '');

        % all good, we are done, leave helper
        return
    end

    % no error, verify that output match
    [samples, coef, mu, expl] = helper_cosmo_pca_wrapper(data, nkeep);

    if nsamples == 1
        % 2025: deal with Octave's recent pca implementation
        % which seems to behave special for row vectors
        assert(isempty(samples));
        assert(isempty(coef));
        assert(isempty(expl));
        return
    end

    expected = {ref_samples, ref_coef, ref_mu, ref_expl};
    actual = {samples, coef, mu, expl};

    for k = 1:numel(expected)
        assertElementsAlmostEqual(expected{k}, actual{k}, ...
                                  'relative', 1e-5);
    end

function test_pca_retain_is_row_vector()
    % for square random input with n rows, the 'explained' field must be
    % a row vector with n - 1 elements
    n = ceil(10 + rand() * 10);
    data = randn(n);

    [unused, pca_params] = cosmo_pca(data);

    expected_size = [1, n - 1];
    assertEqual(size(pca_params.explained), expected_size);