function predicted = cosmo_classify_lda(samples_train, targets_train, samples_test, opt)
% linear discriminant analysis classifier - without prior
%
% predicted=cosmo_classify_lda(samples_train, targets_train, samples_test[,opt])
%
% Inputs:
% samples_train PxR training data for P samples and R features
% targets_train Px1 training data classes
% samples_test QxR test data
% opt Optional struct with optional fields:
% .regularization Used to regularize covariance matrix. Default .01
% .max_feature_count Used to set the maximum number of features,
% defaults to 5000. If R is larger than this
% value, an error is raised. This is a (conservative)
% safety limit to avoid huge memory consumption
% for large values of R, because training the
% classifier typically requires on the order of 8*R^2
% bytes of memory
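% (as an illustration: at the default limit of
% R=5000 features, this amounts to about
% 8*5000^2 bytes, i.e. 200 MB)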
%
% Output:
% predicted Qx1 predicted data classes for samples_test
%
% Example:
% ds=cosmo_synthetic_dataset('ntargets',5,'nchunks',15);
% test_chunk=1;
% te=cosmo_slice(ds,ds.sa.chunks==test_chunk);
% tr=cosmo_slice(ds,ds.sa.chunks~=test_chunk);
% pred=cosmo_classify_lda(tr.samples,tr.sa.targets,te.samples,struct);
% % show targets and predicted labels
% disp([te.sa.targets pred])
% %|| 1 1
% %|| 2 2
% %|| 3 3
% %|| 4 4
% %|| 5 5
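%
% % an illustrative variant: a non-default regularization value
% % (here .1) can be passed through the opt struct
% opt=struct();
% opt.regularization=.1;
% pred=cosmo_classify_lda(tr.samples,tr.sa.targets,te.samples,opt);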
%
% Notes:
% - this classifier does not support a prior; that is, it assumes that
% all classes have the same number of samples in the training set. If
% that is not the case, an error is raised (a balancing sketch is
% shown in the last note below).
% - a safety limit of opt.max_feature_count is implemented because a large
% number of features can crash Matlab / Octave, and/or make it very slow.
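% - as an illustration, unbalanced partitions can be balanced prior to
% cross-validation (assuming a dataset struct ds), for example:
%
% partitions=cosmo_balance_partitions(cosmo_nfold_partitioner(ds),ds);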
%
% (Contributions by Joern Diedrichsen, Tobias Wiestler, Nikolaas Oosterhof)
% # For CoSMoMVPA's copyright information and license terms, #
% # see the COPYING file distributed with CoSMoMVPA. #
if nargin < 4 || isempty(opt)
opt = struct();
end
if ~isfield(opt, 'regularization')
opt.regularization = .01;
end
if ~isfield(opt, 'max_feature_count')
opt.max_feature_count = 5000;
end
% support repeated testing on different data after training on the same
% data each time. This is achieved by caching the training data,
% options, and the associated model
persistent cached_targets_train
persistent cached_samples_train
persistent cached_opt
persistent cached_model
if isequal(cached_targets_train, targets_train) && ...
isequal(cached_opt, opt) && ...
isequal(cached_samples_train, samples_train)
% use cache
model = cached_model;
else
% train classifier
model = train(samples_train, targets_train, opt);
% store model
cached_targets_train = targets_train;
cached_samples_train = samples_train;
cached_opt = opt;
cached_model = model;
end
% test classifier
predicted = test(model, samples_test);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% helper functions
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function model = train(samples_train, targets_train, opt)
% train LDA classifier
[ntrain, nfeatures] = size(samples_train);
ntrain_ = numel(targets_train);
if ntrain_ ~= ntrain
error(['size mismatch: samples_train has %d rows, '...
'targets_train has %d values'], ntrain, ntrain_);
end
if nfeatures > opt.max_feature_count
% compute the storage size per number; typically this is 8 bytes
% (double precision), or 4 bytes if single precision is used.
w = whos('samples_train');
size_per_number = w.bytes / numel(samples_train);
% the covariance matrix is nfeatures x nfeatures
mem_required = nfeatures^2 * size_per_number;
mem_required_mb = round(mem_required / 1e6); % in megabytes
error(['A large number of features (%d) was found, '...
'exceeding the safety limit max_feature_count=%d '...
'for the %s function.\n'...
'This limit is imposed because computing the '...
'covariance matrix will require on the order of '...
'%.0f MB of memory, and inverting it may take '...
'a long time.\n'...
'The safety limit can be changed by setting '...
'the ''max_feature_count'' option to another '...
'value, but large values ***may freeze or crash '...
'the machine***.'], ...
nfeatures, opt.max_feature_count, ...
mfilename(), mem_required_mb);
end
classes = fast_vector_unique(targets_train);
nclasses = numel(classes);
class_mean = zeros(nclasses, nfeatures); % class means
class_cov = zeros(nfeatures); % within-class variability
% compute mean and (co)variance
for k = 1:nclasses
% select data in this class
msk = targets_train == classes(k);
% number of samples in k-th class
n = sum(msk);
if k == 1
if n < 2
error(['Need at least two samples per class '...
'in training']);
end
nfirst = n; % keep track of number of samples
elseif nfirst ~= n
error(['Different number of samples per class (%d and %d) - this '...
'is not supported. When using partitions, '...
'consider using cosmo_balance_partitions'], ...
n, nfirst);
end
class_samples = samples_train(msk, :);
class_mean(k, :) = sum(class_samples, 1) / n; % class mean
res = bsxfun(@minus, class_samples, class_mean(k, :)); % residuals
class_cov = class_cov + res' * res; % accumulate within-class scatter
end
% apply regularization
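% (the pooled covariance is shrunk towards a scaled identity matrix:
% class_cov_reg = class_cov + lambda*(trace(class_cov)/nfeatures)*eye,
% with lambda=opt.regularization; this keeps the matrix invertible even
% when nfeatures exceeds the number of training samples)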
regularization = opt.regularization;
class_cov = class_cov / ntrain;
reg = eye(nfeatures) * trace(class_cov) / max(1, nfeatures);
class_cov_reg = class_cov + reg * regularization;
% linear discriminant
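% (each class k gets weights w_k = mu_k / class_cov_reg and offset
% o_k = mu_k * w_k'; test() below assigns the class maximizing
% x * w_k' - .5 * o_k, i.e. standard LDA with equal priors)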
class_weight = class_mean / class_cov_reg;
class_offset = sum(class_weight .* class_mean, 2);
model = struct();
model.class_offset = class_offset;
model.class_weight = class_weight;
model.classes = classes;
function predicted = test(model, samples_test)
% test LDA classifier
class_offset = model.class_offset;
class_weight = model.class_weight;
classes = model.classes;
if size(samples_test, 2) ~= size(class_weight, 2)
error('test set has %d features, train set has %d', ...
size(samples_test, 2), size(class_weight, 2));
end
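% project each test sample onto each class' discriminant function;
% class_proj has one row per class and one column per test sample,
% and the class with the maximal projection wins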
class_proj = bsxfun(@plus, -.5 * class_offset, class_weight * samples_test');
[unused, class_idxs] = max(class_proj);
predicted = classes(class_idxs);
function unq_xs = fast_vector_unique(xs)
% fast replacement for the builtin unique, for numeric vector input:
% sort the values, then keep each value that differs from its
% predecessor
xs_sorted = sort(xs(:));
idxs = [true; diff(xs_sorted) > 0];
unq_xs = xs_sorted(idxs);