cosmo confusion matrix skl

function [confusion_matrix, classes] = cosmo_confusion_matrix(ds, varargin)
    % Returns a confusion matrix
    %
    % Usage 1: mx=cosmo_confusion_matrix(ds)
    % Usage 2: mx=cosmo_confusion_matrix(targets, predicted)
    %
    %
    % Inputs:
    %   targets     Nx1 targets for N samples, or a dataset struct with
    %               .sa.targets
    %   predicted   NxM predicted labels (from a classifier), for N samples and
    %               M predictions per set of samples
    %
    % Returns:
    %   mx          PxPxM matrix assuming there are P unique targets.
    %               mx(i,j,k)==c means that the i-th target class was classified
    %               as the j-th target class c times for the k-th set of
    %               samples.
    %   classes     Px1 class labels.
    %
    % Example:
    %     ds=cosmo_synthetic_dataset('ntargets',3,'nchunks',4);
    %     args=struct();
    %     args.partitions=cosmo_nchoosek_partitioner(ds,1);
    %     args.output='winner_predictions';
    %     args.classifier=@cosmo_classify_lda;
    %     pred_ds=cosmo_crossvalidation_measure(ds,args);
    %     confusion=cosmo_confusion_matrix(pred_ds.sa.targets,pred_ds.samples);
    %     cosmo_disp(confusion)
    %     %|| [ 3         0         1
    %     %||   0         3         1
    %     %||   1         0         3 ]
    %     confusion_alt=cosmo_confusion_matrix(pred_ds);
    %     isequal(confusion,confusion_alt)
    %     %|| true
    %     %
    %     % run a searchlight with tiny radius of 1 voxel (3 is more common)
    %     nbrhood=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
    %     measure=@cosmo_crossvalidation_measure;
    %     sl_ds=cosmo_searchlight(ds,nbrhood,measure,args,'progress',false);
    %     %
    %     % the confusion matrix is 3x3x6, that is 6 3x3 confusion
    %     % matrices. Here the dataset is passed directly
    %     sl_confusion=cosmo_confusion_matrix(sl_ds);
    %     cosmo_disp(sl_confusion)
    %     %|| <double>@3x3x6
    %     %||    (:,:,1) = [ 4         0         0
    %     %||                0         4         0
    %     %||                0         1         3 ]
    %     %||    (:,:,2) = [ 4         0         0
    %     %||                0         4         0
    %     %||                0         1         3 ]
    %     %||    (:,:,3) = [ 2         1         1
    %     %||                0         4         0
    %     %||                1         0         3 ]
    %     %||    (:,:,4) = [ 4         0         0
    %     %||                0         3         1
    %     %||                0         1         3 ]
    %     %||    (:,:,5) = [ 3         0         1
    %     %||                0         4         0
    %     %||                1         1         2 ]
    %     %||    (:,:,6) = [ 3         0         1
    %     %||                0         4         0
    %     %||                1         1         2 ]
    %
    %     % using samples that are not predictions gives an error
    %     ds=cosmo_synthetic_dataset('ntargets',3,'nchunks',4);
    %     confusion=cosmo_confusion_matrix(ds)
    %     %|| error('72 predictions mismatch targets, first is (1,1)=2.211999e+00')
    %
    % Notes:
    %   - this function counts the number of times each sample was classified
    %     as any target
    %
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    [targets, predicted] = get_data(ds, varargin{:});

    % see which classes there are
    [class_indices, classes] = cosmo_index_unique(targets);
    nclasses = numel(class_indices);

    % allocate space for output
    nfeatures = size(predicted, 2);
    confusion_matrix = zeros([nclasses, nclasses, nfeatures]);

    % keep track which predicted samples were in targets
    visited = false(size(predicted));
    %%%% >>> Your code here <<< %%%%

    missing = ~(visited | isnan(predicted));
    if any(missing(:))
        n = sum(missing(:));
        [i, j] = find(missing, 1);
        error(['%d predictions mismatch targets, '...
               'first is (%d,%d)=%d'], n, i, j, predicted(i, j));
    end

function [targets, predicted] = get_data(ds, predicted)
    has_predicted = nargin >= 2;
    is_ds = isstruct(ds);
    if is_ds
        if has_predicted
            error('Need exactly one argument when input is struct');
        end
        % input is a dataset
        cosmo_isfield(ds, 'sa.targets', true);
        cosmo_isfield(ds, 'samples', true);

        predicted = ds.samples;
        targets = ds.sa.targets;
    elseif isnumeric(ds)
        if ~has_predicted
            error('Need two arguments when first argument is numeric');
        end
        targets = ds;
    else
        error('Illegal input: need struct or numeric vector');
    end

    if ~isvector(targets) || size(targets, 2) ~= 1
        error('targets must be column vector');
    end

    if numel(size(predicted)) ~= 2
        error('predictions must be matrix');
    end

    nsamples = numel(targets);
    if size(predicted, 1) ~= nsamples
        error(['Size mismatch: predictions has %d values on first '...
               'dimension, but targets has %d values'], ...
              size(predicted, 1), nsamples);
    end