function [confusion_matrix, classes]=cosmo_confusion_matrix(ds, varargin)
% Compute confusion matrices from targets and classifier predictions
%
% Usage 1: [mx,classes]=cosmo_confusion_matrix(ds)
% Usage 2: [mx,classes]=cosmo_confusion_matrix(targets, predicted)
%
% Inputs:
%   targets      Nx1 targets for N samples, or a dataset struct with
%                fields .sa.targets (Nx1 targets) and .samples (NxM
%                predictions); in the latter case no second argument
%                must be provided
%   predicted    NxM predicted labels (from a classifier), for N samples
%                and M predictions per set of samples
%
% Returns:
%   mx           PxPxM matrix assuming there are P unique targets.
%                mx(i,j,k)==c means that the i-th target class was
%                classified as the j-th target class c times for the
%                k-th set of samples.
%   classes      Px1 class labels.
%
% Example:
%     ds=cosmo_synthetic_dataset('ntargets',3,'nchunks',4);
%     args=struct();
%     args.partitions=cosmo_nchoosek_partitioner(ds,1);
%     args.output='winner_predictions';
%     args.classifier=@cosmo_classify_lda;
%     pred_ds=cosmo_crossvalidation_measure(ds,args);
%     confusion=cosmo_confusion_matrix(pred_ds.sa.targets,pred_ds.samples);
%     cosmo_disp(confusion)
%     %|| [ 3 0 1
%     %|| 0 3 1
%     %|| 1 0 3 ]
%     confusion_alt=cosmo_confusion_matrix(pred_ds);
%     isequal(confusion,confusion_alt)
%     %|| true
%     %
%     % run a searchlight with tiny radius of 1 voxel (3 is more common)
%     nbrhood=cosmo_spherical_neighborhood(ds,'radius',1,'progress',false);
%     measure=@cosmo_crossvalidation_measure;
%     sl_ds=cosmo_searchlight(ds,nbrhood,measure,args,'progress',false);
%     %
%     % the confusion matrix is 3x3x6, that is 6 3x3 confusion
%     % matrices. Here the dataset is passed directly
%     sl_confusion=cosmo_confusion_matrix(sl_ds);
%     cosmo_disp(sl_confusion)
%     %|| <double>@3x3x6
%     %|| (:,:,1) = [ 4 0 0
%     %|| 0 4 0
%     %|| 0 1 3 ]
%     %|| (:,:,2) = [ 4 0 0
%     %|| 0 4 0
%     %|| 0 1 3 ]
%     %|| (:,:,3) = [ 2 1 1
%     %|| 0 4 0
%     %|| 1 0 3 ]
%     %|| (:,:,4) = [ 4 0 0
%     %|| 0 3 1
%     %|| 0 1 3 ]
%     %|| (:,:,5) = [ 3 0 1
%     %|| 0 4 0
%     %|| 1 1 2 ]
%     %|| (:,:,6) = [ 3 0 1
%     %|| 0 4 0
%     %|| 1 1 2 ]
%
%     % using samples that are not predictions gives an error
%     ds=cosmo_synthetic_dataset('ntargets',3,'nchunks',4);
%     confusion=cosmo_confusion_matrix(ds)
%     %|| error('72 predictions mismatch targets, first is (1,1)=2.211999e+00')
%
% Notes:
%   - this function counts the number of times each sample was classified
%     as any target
%   - predictions that are NaN are not counted in the output and do not
%     raise an error; any other prediction that is not equal to one of
%     the target values raises an error
%
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    [targets,predicted]=get_data(ds,varargin{:});

    % group the sample indices by target class
    [class_indices,classes]=cosmo_index_unique(targets);
    nclasses=numel(class_indices);

    % one PxP confusion matrix for each column in predicted
    nprediction_sets=size(predicted,2);
    confusion_matrix=zeros([nclasses,nclasses,nprediction_sets]);

    % accounted(r,c) is set to true once predicted(r,c) was found to be
    % equal to one of the classes
    accounted=false(size(predicted));

    for row=1:nclasses
        % predictions for all samples of which the target is the
        % row-th class
        members=class_indices{row};
        member_predictions=predicted(members,:);

        % matches with any class, for the samples in this row only
        row_hits=false(size(member_predictions));

        for col=1:nclasses
            is_hit=member_predictions==classes(col);

            % count per prediction set; sum explicitly along the first
            % dimension so that a single sample is handled properly
            confusion_matrix(row,col,:)=sum(is_hit,1);
            row_hits=row_hits | is_hit;
        end

        accounted(members,:)=accounted(members,:) | row_hits;
    end

    % every prediction must either match a class or be NaN
    unmatched=~accounted & ~isnan(predicted);

    if any(unmatched(:))
        n_unmatched=sum(unmatched(:));
        [first_row,first_col]=find(unmatched,1);
        error(['%d predictions mismatch targets, '...
                    'first is (%d,%d)=%d'],...
                    n_unmatched,first_row,first_col,...
                    predicted(first_row,first_col));
    end
function [targets,predicted]=get_data(ds, predicted)
% Extract targets and predictions from the inputs and verify their sizes
%
% Inputs:
%   ds           either a dataset struct with fields .sa.targets and
%                .samples, or a numeric Nx1 vector with targets
%   predicted    NxM predictions; required if ds is numeric, and not
%                allowed if ds is a struct
%
% Returns:
%   targets      Nx1 targets
%   predicted    NxM predictions
%
% An error is raised if the number or type of the inputs is not
% supported, if targets is not a column vector, if predicted has more
% than two dimensions, or if predicted and targets have a different
% number of rows

    predicted_given=nargin>=2;

    % only dataset structs and numeric arrays are supported
    if ~(isstruct(ds) || isnumeric(ds))
        error('Illegal input: need struct or numeric vector');
    end

    if isstruct(ds)
        % dataset input: it provides both predictions and targets
        if predicted_given
            error('Need exactly one argument when input is struct');
        end

        cosmo_isfield(ds,'sa.targets',true);
        cosmo_isfield(ds,'samples',true);

        predicted=ds.samples;
        targets=ds.sa.targets;
    else
        % numeric input: predictions must come from the second argument
        if ~predicted_given
            error('Need two arguments when first argument is numeric');
        end

        targets=ds;
    end

    targets_is_column=isvector(targets) && size(targets,2)==1;
    if ~targets_is_column
        error('targets must be column vector');
    end

    if ndims(predicted)~=2
        error('predictions must be matrix');
    end

    % each target must have a corresponding row in predicted
    ntargets=numel(targets);
    nrows_predicted=size(predicted,1);
    if nrows_predicted~=ntargets
        error(['Size mismatch: predictions has %d values on first '...
                    'dimension, but targets has %d values'],...
                    nrows_predicted,ntargets);
    end