test mask dim intersect

function test_suite = test_mask_dim_intersect
    % tests for cosmo_mask_dim_intersect
    %
    % Entry point of this test file; the individual test cases are the
    % local functions defined below.
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #
    try % assignment of 'localfunctions' is necessary in Matlab >= 2016
        test_functions = localfunctions();
    catch % no problem; early Matlab versions can use initTestSuite fine
    end
    % NOTE(review): initTestSuite is provided by the unit test framework
    % and presumably sets the 'test_suite' output from this workspace
    % (including 'test_functions' when present) - confirm against the
    % framework in use
    initTestSuite;

function test_mask_dim_intersect_feature_dim()
    % run the shared intersection checks along the feature dimension
    feature_dim = 2;
    helper_test_mask_dim_intersect(feature_dim);

function test_mask_dim_intersect_sample_dim()
    % run the shared intersection checks along the sample dimension
    sample_dim = 1;
    helper_test_mask_dim_intersect(sample_dim);

function helper_test_mask_dim_intersect(dim)
    % Shared checks for cosmo_mask_dim_intersect along dimension dim
    % (1 or 2).
    %
    % Several datasets are built that each keep a different random subset
    % of the features of the same synthetic dataset. Every dataset is
    % stored as a pair stacked along the other dimension: the first slice
    % holds random data (the input for cosmo_mask_dim_intersect), the
    % second slice holds the original feature indices (used only for
    % verification).
    along_samples = dim == 1;
    stack_dim = 3 - dim;

    n_datasets = ceil(rand() * 2 + 2);
    fraction_kept = .8;

    % optimization due to somewhat slow dim_transpose: use a small
    % dataset when the dimension has to be transposed
    if along_samples
        dataset_size = 'small';
    else
        dataset_size = 'big';
    end

    paired_cell = cell(n_datasets, 1);
    for j = 1:n_datasets
        ds_full = cosmo_synthetic_dataset('size', dataset_size, ...
                                          'seed', 0, ...
                                          'nchunks', 1, 'ntargets', 1);
        n_features = size(ds_full.samples, 2);

        % each sample value is the index of its feature
        ds_full.samples = 1:n_features;

        % keep a random subset of the features, in random order
        shuffled = randperm(n_features);
        first_few = round(1:(fraction_kept * n_features));
        ds_subset = cosmo_slice(ds_full, shuffled(first_few), 2);

        if along_samples
            ds_subset = cosmo_dim_transpose(ds_subset, ...
                                            ds_full.a.fdim.labels, 1);
        end

        % fill with random data so that the function to test cannot
        % use the indices
        ds_random = ds_subset;
        ds_random.samples = randn(size(ds_subset.samples));

        % first slice: random data; second slice: feature indices
        paired_cell{j} = cosmo_stack({ds_random, ds_subset}, stack_dim);
    end

    % split every pair into its random-data part and its index part
    random_cell = cell(size(paired_cell));
    index_cell = cell(size(paired_cell));
    for j = 1:n_datasets
        random_cell{j} = cosmo_slice(paired_cell{j}, 1, stack_dim);
        index_cell{j} = cosmo_slice(paired_cell{j}, 2, stack_dim);
    end

    [indices, intersect_cell] = cosmo_mask_dim_intersect(random_cell, dim);

    if ~along_samples
        % must have same output using default second argument
        [indices_default, intersect_cell_default] = ...
                                cosmo_mask_dim_intersect(random_cell);

        assertEqual(indices, indices_default);
        assertEqual(intersect_cell, intersect_cell_default);
    end

    % the returned datasets must be the inputs sliced by the returned
    % indices
    assert(iscell(intersect_cell));
    assertEqual(size(intersect_cell), size(random_cell));
    for j = 1:n_datasets
        selection = indices{j};
        assert(all(selection > 0));
        assert(all(isfinite(selection)));

        expected_ds = cosmo_slice(random_cell{j}, selection, dim);
        assertEqual(expected_ds, intersect_cell{j});
    end

    % see which feature indices are common across all datasets
    common = 1:n_features;
    for j = 1:n_datasets
        present = index_cell{j}.samples;

        % feature indices within a single dataset must be unique
        assertEqual(sort(present), unique(present));
        common = intersect(common, present);
    end

    % verify that indices match across all datasets
    for j = 1:n_datasets
        % feature indices selected from this dataset
        selected = index_cell{j}.samples(indices{j});

        % selection must cover exactly the common feature indices
        if isempty(selected)
            assert(isempty(common));
        else
            assertEqual(sort(selected(:)), sort(common(:)));
        end

        % must return indices in the same order for every dataset
        ds_selected = cosmo_slice(index_cell{j}, indices{j}, dim);
        if j == 1
            reference_ds = ds_selected;
        else
            assertEqual(reference_ds, ds_selected);
        end
    end

function test_mask_dim_intersect_identity
    % intersecting a single, feature-permuted dataset with itself must
    % undo the permutation; this is checked for several dataset types

    ds_types = {'fmri', 'surface', 'source', 'timelock', 'timefreq'};
    n_types = numel(ds_types);

    for j = 1:n_types
        ds_type = ds_types{j};
        ds = cosmo_synthetic_dataset('size', 'big', 'type', ds_type, ...
                                     'ntargets', 1, 'nchunks', 1);

        % source datasets are called with an extra matrix_labels option
        if strcmp(ds_type, 'source')
            extra_args = {'matrix_labels', {'pos'}};
        else
            extra_args = {};
        end

        % each sample value is the index of its feature
        n_features = size(ds.samples, 2);
        ds.samples(1, :) = 1:n_features;

        % shuffle the features
        shuffled = randperm(n_features);
        ds = cosmo_slice(ds, shuffled, 2);

        [indices_cell, result_cell] = cosmo_mask_dim_intersect({ds}, 2, ...
                                                               extra_args{:});
        assert(numel(indices_cell) == 1);
        assert(numel(result_cell) == 1);

        selection = indices_cell{1};
        ds_result = result_cell{1};

        % every feature must be selected exactly once
        assertEqual(sort(selection), 1:n_features);

        % output dataset must agree with slicing by the output indices
        assertEqual(ds_result, cosmo_slice(ds, selection, 2));

        % features must be back in their original order
        assertEqual(ds_result.samples, 1:n_features);

    end

function test_mask_dim_intersect_exceptions()
    % verifies that cosmo_mask_dim_intersect throws an exception for
    % unsupported inputs

    % helper: assert that calling cosmo_mask_dim_intersect with the given
    % arguments throws (any exception identifier is accepted)
    aet = @(varargin)assertExceptionThrown(@() ...
                                           cosmo_mask_dim_intersect(varargin{:}), '');
    % a single dataset passed directly (not inside a cell) must raise
    aet(cosmo_synthetic_dataset());

    % input must be a cell with datasets
    aet(struct);
    aet('foo');
    aet({struct, struct});

    % repeated features not supported
    ds = cosmo_synthetic_dataset('size', 'big');
    ds2 = cosmo_stack({ds, ds}, 2);
    aet({ds2});

    % not even a single feature duplicated is allowed
    % (the dataset must be wrapped in a cell; otherwise the exception
    %  would be caused by the non-cell input rather than by the
    %  duplicated feature)
    col_index = ceil(rand() * size(ds.samples, 2));
    ds_extra_col = cosmo_stack({ds, cosmo_slice(ds, col_index, 2)}, 2);
    aet({ds_extra_col});

    % dim must be 1 or 2
    aet({ds}, 3);
    aet({ds}, struct);
    aet({ds}, [1 2]);

function test_mask_dim_intersect_missing_dim()
    % an exception must be thrown when one dataset lacks the 'k' feature
    % attribute that the other dataset has
    ds_complete = cosmo_synthetic_dataset();

    ds_stripped = ds_complete;
    ds_stripped.fa = rmfield(ds_stripped.fa, 'k');

    % the data itself is unchanged; only the attribute differs
    assertEqual(ds_complete.samples, ds_stripped.samples);

    intersect_both = @()cosmo_mask_dim_intersect({ds_complete, ds_stripped});
    assertExceptionThrown(intersect_both, '');

function test_mask_dim_intersect_nonmatching_dim()
    % an exception must be thrown when the 'k' dimension has been removed
    % from one of the datasets with cosmo_dim_remove
    ds_complete = cosmo_synthetic_dataset();
    ds_reduced = cosmo_dim_remove(ds_complete, 'k');

    % the data itself is unchanged; only the dimensions differ
    assertEqual(ds_complete.samples, ds_reduced.samples);

    intersect_both = @()cosmo_mask_dim_intersect({ds_complete, ds_reduced});
    assertExceptionThrown(intersect_both, '');

function test_mask_dim_intersect_renamed_dim()
    % an exception must be thrown when the 'k' dimension has been renamed
    % to 'kk' in one of the datasets
    ds_original = cosmo_synthetic_dataset();
    ds_renamed = cosmo_dim_rename(ds_original, 'k', 'kk');

    % the data itself is unchanged; only the dimension name differs
    assertEqual(ds_original.samples, ds_renamed.samples);

    intersect_both = @()cosmo_mask_dim_intersect({ds_original, ds_renamed});
    assertExceptionThrown(intersect_both, '');