cosmo check partitions skl

function is_ok=cosmo_check_partitions(partitions, ds, varargin)
% check whether partitions are balanced and not double-dippy
%
% cosmo_check_partitions(partitions, ds, varargin)
%
% Inputs:
%   partitions  struct with partitions, e.g. from
%               cosmo_{nfold,oddeven,nchoosek}_partitioner. It must have
%               cell fields .train_indices and .test_indices with the same
%               number of elements (one element per fold).
%   ds          dataset struct with fields .sa.{targets,chunks}
%   varargin    (optional) options, passed to cosmo_structjoin (e.g. as
%               key-value pairs, see examples), with optional field:
%     .unbalanced_partitions_ok    if set to true, then unbalanced
%                                  partitions (with a different number of
%                                  samples of each class in the training
%                                  set of a fold) are ok (default: false)
% Output:
%    is_ok      boolean, true if partitions are ok
%
% Throws:
%   - an error if partitions are double dippy (a chunk occurs in both the
%     train and test set of a fold)
%   - an error (unless unbalanced_partitions_ok is true) if partitions are
%     not balanced, or if different folds train on different sets of
%     targets
%   - an error if .train_indices or .test_indices are missing, are not
%     cells, differ in their number of folds, or contain indices that are
%     empty, non-integer, or outside the range 1:nsamples
%
% Warnings:
%   - a warning (through cosmo_warning) is shown if .train_indices or
%     .test_indices are not sorted in one or more folds
%
% Examples:
%     ds=struct();
%     ds.samples=zeros(9,2);
%     ds.sa.targets=[1 1 2 2 2 3 3 3 3]';
%     ds.sa.chunks=[1 2 2 1 1 1 2 2 2]';
%     % make unbalanced partitions
%     partitions=cosmo_nfold_partitioner(ds);
%     cosmo_check_partitions(partitions,ds);
%     %|| error('Unbalance in partition 1 [...]')
%     %
%     % disable unbalanced check
%     cosmo_check_partitions(partitions,ds,'unbalanced_partitions_ok',true)
%     %|| true
%     %
%     % balance partitions and check without unbalanced check
%     partitions=cosmo_balance_partitions(partitions,ds);
%     cosmo_check_partitions(partitions,ds,'unbalanced_partitions_ok',false)
%     %|| true
%     %
%     % make the partitions double dippy
%     partitions.train_indices{1}=partitions.test_indices{1};
%     cosmo_check_partitions(partitions,ds,'unbalanced_partitions_ok',true)
%     %|| error('double dipping in fold 1: chunk 1 is in train and test set')
%     %
%     % make partitions empty
%     partitions.train_indices{1}=[];
%     cosmo_check_partitions(partitions,ds);
%     %|| error('partition 1: .train_indices are empty')
%     %
%     % partitions have values outside range
%     partitions.train_indices{1}=100;
%     cosmo_check_partitions(partitions,ds);
%     %|| error('partition 1: .train_indices are outside range 1:9');
%     %
%     % use non-integers
%     partitions.train_indices{1}=1.5;
%     cosmo_check_partitions(partitions,ds);
%     %|| error('partition 1: .train_indices are not integers');
%
% Notes:
%   - the reason to require balancing by default is that chance level is
%     1/nclasses, which is useful for many subsequent analyses.
%   - if this function raises an exception for partitions, consider running
%     partitions=cosmo_balance_partitions(partitions,...).
%
% See also: cosmo_balance_partitions, cosmo_nfold_partitioner
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    % process input arguments
    defaults=struct();
    defaults.unbalanced_partitions_ok=false;

    params=cosmo_structjoin(defaults,varargin{:});

    % whether to check for equal number of samples of each class in
    % the training set of each fold
    check_balance=~params.unbalanced_partitions_ok;

    % whether to check for the same chunk in train and test set
    check_double_dipping=true;

    % check dataset
    check_dataset(ds);

    % ensure it has targets and chunks
    cosmo_isfield(ds,{'sa.targets','sa.chunks'},true);
    targets=ds.sa.targets;
    chunks=ds.sa.chunks;

    if check_balance
        % sample2class maps each sample to the index of its target in
        % classes, so that class counts can be computed with histc
        [classes,unused,sample2class]=unique(targets);
    end

    % must have .train_indices and .test_indices
    cosmo_isfield(partitions,{'train_indices','test_indices'},true);
    train_indices=partitions.train_indices;
    test_indices=partitions.test_indices;

    if ~iscell(train_indices) || ~iscell(test_indices)
        error('.train_indices and .test_indices must be a cell');
    end

    % ensure equal number of partitions for train and test
    npartitions=numel(train_indices);
    if npartitions~=numel(test_indices)
        error('Partition count mismatch for train and test: %d ~= %d',...
                    npartitions,numel(test_indices));
    end

    nsamples=numel(targets);

    % column 1: unsorted train indices; column 2: unsorted test indices.
    % Unsorted indices are not an error; they are reported in a single
    % warning after all folds have been checked.
    unsorted_train_test_fold=false(npartitions,2);

    for k=1:npartitions
        train_idxs=train_indices{k};
        test_idxs=test_indices{k};

        % range checks come first so that the indexing below is safe
        check_range(train_idxs,nsamples,k,'train');
        check_range(test_idxs,nsamples,k,'test');

        if ~issorted(train_idxs)
            unsorted_train_test_fold(k,1)=true;
        end

        if ~issorted(test_idxs)
            unsorted_train_test_fold(k,2)=true;
        end

        if check_balance
            % counts of number of samples in each each class must be the
            % same
            train_classes=sample2class(train_idxs);
            h=histc(train_classes,1:numel(classes));

            % the set of classes present in the training set must be the
            % same across folds; compare each fold against the first one
            h_nonzero_idxs=find(h>0);
            if k==1
                first_h_nonzero_idxs=h_nonzero_idxs;
            elseif ~isequal(first_h_nonzero_idxs,h_nonzero_idxs)
                error(['Different targets used for training '...
                        'in partition %d and %d. This is weird. '...
                        'Consider the following scenarios:\n'...
                        '(1) You made the partitions manually. It is '...
                        'possible that you made a mistake.\n'...
                        '(2) You try to do cross-decoding and '...
                        'partitions were defined using '...
                        'cosmo_nchoosek_partitioner. This usually '...
                        'requires a *re-assignment* of .sa.targets '...
                        'so that the unique targets values are the '...
                        'same for the samples used for training '...
                        'and for testing. Please read the '...
                        'documentation of cosmo_nchoosek_partitioner '...
                        '(especially the examples) carefully and '...
                        'verify that '...
                        'the .sa.targets are (re)assigned properly.\n'...
                        '(3) You used a cosmo_ function to set the '...
                        'partitions, but case 2 does not apply. '...
                        'It may indicate a bug; in that case, '...
                        'please get in touch with the CoSMoMVPA '...
                        'developers'],1,k);
            end

            h_nonzero=h(h_nonzero_idxs);

            if ~all(h_nonzero(1)==h_nonzero)
                % report the first class whose count differs from the
                % count of the first class present
                idx=find(h_nonzero(1)~=h_nonzero,1);

                pos_first=h_nonzero_idxs(1);
                pos_other=h_nonzero_idxs(idx);

                error(['Unbalance in partition %d, '...
                     'classes %d (#=%d) and %d (#=%d). '...
                     'Consider the following scenarios:\n'...
                     '(1) the input is an MEEG dataset, or an fMRI '...
                     'dataset with missing samples (trials) in some '...
                     'of the runs. You probably want to use '...
                     'cosmo_balance_partitions.\n'...
                     '(2) all chunks are unique; this is either a (a) '...
                     'between-participants design and you '...
                     'try to discriminate between participants (e.g. '...
                     'patients versus controls); or this is an MEEG '...
                     'dataset where all samples are assumed to be '...
                     'independent. You probably want to use '...
                     'cosmo_independent_samples_partitioner to define '...
                     'the partitions\n'...
                     '(3) the input is an fMRI dataset using beta '...
                     'estimates or t-statistics estimated using the '...
                     'GLM, with one sample '...
                     'of each condition per run (e.g. nfold partition) '...
                     'or per set of runs (e.g. odd-even partition). '...
                     'Probably a mistake was made setting .sa.chunks '...
                     'or .sa.targets\n'...
                     '(4) you *really* know what you are doing: '...
                     'as a litmus test, you would be comfortable '...
                     'implementing a bootstrapping algorithm to '...
                     'estimate the cdf of your measure of interest '...
                     'under some null hypothesis. You can set '...
                     'unbalanced_partitions_ok to true as an option'],...
                      k, classes(pos_first), h(pos_first), ...
                      classes(pos_other), h(pos_other));
            end
        end

        if check_double_dipping
            % no chunk allowed to be both in the train and test set;
            % the match is computed once and reused to report the
            % first offending chunk
            ctrain=chunks(train_idxs);
            in_test=cosmo_match(ctrain,chunks(test_idxs));
            if any(in_test)
                idx=find(in_test,1);

                error(['double dipping in fold %d: chunk %d is in '...
                       'train and test set'],k,ctrain(idx));
            end
        end
    end

    has_unsorted_indices=any(unsorted_train_test_fold(:));
    if has_unsorted_indices
        % msg_parts{1} and msg_parts{2} are filled in only for the kind of
        % indices (train or test) that are unsorted; the empty ones are
        % dropped before joining
        fns={'train_indices','test_indices'};
        msg_parts={'','',['Unsorted partitions can lead to unintuitive '...
            'ordering of results (e.g. when using fold_predictions)']};
        for k=1:2
            fn=fns{k};
            folds=find(unsorted_train_test_fold(:,k));
            if ~isempty(folds)
                msg_parts{k}=sprintf('Unsorted %s in fold(s):%s',...
                                    fn,sprintf(' %d',folds));
            end
        end

        msk=~cellfun(@isempty,msg_parts);
        msg=cosmo_strjoin(msg_parts(msk),'\n');
        cosmo_warning(msg);
    end

    is_ok=true;

function check_range(idxs,nsamples,partition,label)
% helper: raise an error if the indices of one fold cannot be used
%
%   idxs        sample indices of one fold
%   nsamples    number of samples in the dataset
%   partition   fold number, only used in the error message
%   label       'train' or 'test', only used in the error message
%
% Indices are rejected if they are empty, contain non-integer values, or
% lie outside 1:nsamples; only the first applicable reason is reported.
% Returns silently if none of these apply.
    if isempty(idxs)
        problem='empty';
    elseif ~isequal(round(idxs),idxs)
        problem='not integers';
    elseif min(idxs)<1 || max(idxs)>nsamples
        problem=sprintf('outside range 1:%d',nsamples);
    else
        % indices are fine
        return;
    end

    error('partition %d: .%s_indices are %s',partition,label,problem);

function check_dataset(ds)
% helper: ensure ds is a struct with field .sa and validate it with
% cosmo_check_dataset.
%
% The .sa of the most recent dataset that passed cosmo_check_dataset is
% remembered across calls; if the current .sa is equal to it, the call to
% cosmo_check_dataset is skipped. The remembered value is only updated
% after cosmo_check_dataset returns without error.
    persistent last_validated_sa;

    if ~(isstruct(ds) && isfield(ds,'sa'))
        error('second input must be a dataset struct with field .sa');
    end

    if isequal(ds.sa,last_validated_sa)
        % same .sa as the last validated dataset; nothing to do
        return;
    end

    cosmo_check_dataset(ds);
    last_validated_sa=ds.sa;