% cosmo classify lda

function predicted = cosmo_classify_lda(samples_train, targets_train, samples_test, opt)
% linear discriminant analysis (LDA) classifier, without a class prior
%
% predicted=cosmo_classify_lda(samples_train, targets_train, samples_test[,opt])
%
% Inputs:
%   samples_train       PxR training data for P samples and R features
%   targets_train       Px1 training data classes
%   samples_test        QxR test data
%   opt                 Optional struct with the following optional fields:
%    .regularization    Used to regularize the covariance matrix.
%                       Default: .01
%    .max_feature_count Maximum number of features allowed. Default: 5000.
%                       If R is larger than this value, an error is
%                       raised. This is a (conservative) safety limit that
%                       avoids huge memory consumption for large values of
%                       R, because training the classifier typically
%                       requires in the order of 8*R^2 bytes of memory.
%
% Output:
%   predicted          Qx1 predicted data classes for samples_test
%
% Example:
%     ds=cosmo_synthetic_dataset('ntargets',5,'nchunks',15);
%     test_chunk=1;
%     te=cosmo_slice(ds,ds.sa.chunks==test_chunk);
%     tr=cosmo_slice(ds,ds.sa.chunks~=test_chunk);
%     pred=cosmo_classify_lda(tr.samples,tr.sa.targets,te.samples,struct);
%     % show targets and predicted labels
%     disp([te.sa.targets pred])
%     %||       1     1
%     %||       2     2
%     %||       3     3
%     %||       4     4
%     %||       5     5
%
% Notes:
% - this classifier does not support a prior; that is, it assumes that all
%   classes have the same number of samples. If that is not the case, an
%   error is thrown.
% - a safety limit of opt.max_feature_count is implemented because a large
%   number of features can crash Matlab / Octave, and/or make it very slow.
% - the most recently trained model is kept in persistent variables,
%   together with the training samples, targets and options that produced
%   it. A later call with identical samples_train, targets_train and opt
%   reuses that model instead of training again.
%
% (Contributions by Joern Diedrichsen, Tobias Wiestler, Nikolaas Oosterhof)
% #   For CoSMoMVPA's copyright information and license terms,   #
% #   see the COPYING file distributed with CoSMoMVPA.           #

    % start from an empty option struct when none was given
    if nargin<4 || isempty(opt)
        opt=struct();
    end

    % add default values for options the caller did not set
    default_names={'regularization','max_feature_count'};
    default_values={.01,5000};
    for i_default=1:numel(default_names)
        name=default_names{i_default};
        if ~isfield(opt,name)
            opt.(name)=default_values{i_default};
        end
    end

    % Model cache: when called repeatedly with the same training data and
    % options (e.g. testing on different data each time), the model
    % trained on the first call is reused.
    persistent cached_targets_train cached_samples_train cached_opt cached_model

    can_reuse_model=isequal(cached_targets_train, targets_train) && ...
                    isequal(cached_opt, opt) && ...
                    isequal(cached_samples_train, samples_train);

    if ~can_reuse_model
        % (re)train, then remember the inputs that produced this model.
        % If training raises an error, the previous cache is left intact.
        cached_model=train(samples_train, targets_train, opt);
        cached_targets_train=targets_train;
        cached_samples_train=samples_train;
        cached_opt=opt;
    end

    % apply the (possibly cached) model to the test samples
    predicted=test(cached_model, samples_test);
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % helper functions
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    function model=train(samples_train, targets_train, opt)
        % train an LDA classifier (without a class prior)
        %
        % model=train(samples_train, targets_train, opt)
        %
        % Inputs:
        %   samples_train   PxR training data for P samples and R features
        %   targets_train   P-element vector with training data classes
        %   opt             struct with fields .regularization and
        %                   .max_feature_count
        %
        % Output:
        %   model           struct with fields:
        %     .classes      Cx1 sorted unique values of targets_train
        %     .class_weight CxR discriminant weights: the class means
        %                   right-divided by the regularized pooled
        %                   within-class covariance matrix
        %     .class_offset Cx1 row-wise sums of class_weight.*class_mean
        %
        % An error is raised if the number of rows of samples_train and the
        % number of values in targets_train differ, if there are no
        % training samples, if R exceeds opt.max_feature_count, if a class
        % has fewer than two samples, or if classes differ in their number
        % of samples.

        [ntrain, nfeatures]=size(samples_train);
        ntrain_=numel(targets_train);

        if ntrain_~=ntrain
            error(['size mismatch: samples_train has %d rows, '...
                    'targets_train has %d values'],ntrain,ntrain_);
        end

        if ntrain==0
            % an empty training set is not caught by the per-class checks
            % below (there would be no classes to check)
            error(['No training samples; need at least two samples '...
                    'per class in training']);
        end

        if nfeatures>opt.max_feature_count
            % compute size requirements for numbers.
            % typically this is 8 bytes per number, unless single precision
            % is used.
            w=whos('samples_train');
            size_per_number=w.bytes/numel(samples_train);

            % the covariance matrix is nfeatures x nfeatures
            mem_required=nfeatures^2*size_per_number;
            mem_required_mb=mem_required/1e6; % in megabytes

            error(['A large number of features (%d) was found, '...
                    'exceeding the safety limit max_feature_count=%d '...
                    'for the %s function.\n'...
                    'This limit is imposed because computing the '...
                    'covariance matrix will require in the order of '...
                    '%.0f MB of memory, and inverting it may take '...
                    'a long time.\n'...
                    'The safety limit can be changed by setting '...
                    'the ''max_feature_count'' option to another '...
                    'value, but large values ***may freeze or crash '...
                    'the machine***.'],...
                    nfeatures,opt.max_feature_count,...
                    mfilename(),mem_required_mb);
        end

        classes=fast_vector_unique(targets_train);
        nclasses=numel(classes);

        class_mean=zeros(nclasses,nfeatures);   % class means
        class_cov=zeros(nfeatures);             % within-class variability

        % compute class means and accumulate the within-class scatter
        for k=1:nclasses
            % select data in this class
            msk=targets_train==classes(k);

            % number of samples in k-th class
            n=sum(msk);

            % every class must have as many samples as the first one, so
            % checking for at least two samples in the first class suffices
            if k==1
                if n<2
                    error(['Need at least two samples per class '...
                                'in training']);
                end
                nfirst=n;
            elseif nfirst~=n
                error(['Different number of samples per class '...
                        '(%d and %d) - this is not supported. When '...
                        'using partitions, consider using '...
                        'cosmo_balance_partitions'], ...
                        nfirst, n);
            end

            class_samples=samples_train(msk,:);

            class_mean(k,:) = sum(class_samples,1)/n; % class mean
            res = bsxfun(@minus,class_samples,class_mean(k,:)); % residuals
            class_cov = class_cov+res'*res; % pooled scatter matrix
        end

        % Regularization: add a multiple of the identity matrix, scaled by
        % the mean diagonal value (trace/nfeatures), so that the amount of
        % regularization is relative to the overall variance.
        regularization=opt.regularization;
        class_cov=class_cov/ntrain;
        reg=eye(nfeatures)*trace(class_cov)/max(1,nfeatures);
        class_cov_reg=class_cov+reg*regularization;

        % linear discriminant: class_mean*inv(class_cov_reg), computed with
        % right division rather than an explicit inverse
        class_weight=class_mean/class_cov_reg;
        class_offset=sum(class_weight .* class_mean,2);

        model=struct();
        model.class_offset=class_offset;
        model.class_weight=class_weight;
        model.classes=classes;

function predicted=test(model, samples_test)
    % apply a trained LDA model to test data
    %
    % predicted=test(model, samples_test)
    %
    % Inputs:
    %   model          struct with fields .class_offset (Cx1),
    %                  .class_weight (CxR) and .classes (C class labels)
    %   samples_test   QxR test data
    %
    % Output:
    %   predicted      Q predicted labels (Qx1 when model.classes is a
    %                  column vector or scalar); for each test sample, the
    %                  element of model.classes with the largest
    %                  discriminant value
    %
    % An error is raised if samples_test does not have R columns.

    class_offset=model.class_offset;
    class_weight=model.class_weight;
    classes=model.classes;

    ntrain_features=size(class_weight,2);
    ntest_features=size(samples_test,2);
    if ntest_features~=ntrain_features
        error('test set has %d features, train set has %d',...
                ntest_features,ntrain_features);
    end

    % CxQ matrix of discriminant values, one column per test sample
    class_proj=bsxfun(@plus,-.5*class_offset,class_weight*samples_test');

    % Take the maximum along dimension 1 explicitly. With a single class,
    % class_proj is 1xQ, and max without a dimension argument would reduce
    % along dimension 2 instead.
    [unused,class_idxs]=max(class_proj,[],1);

    % index with a column so the result stays Qx1 even if classes is scalar
    predicted=classes(class_idxs(:));

function unq_xs=fast_vector_unique(xs)
    % sorted unique values of a numeric array, as a column vector
    %
    % unq_xs=fast_vector_unique(xs)
    %
    % Sort-based alternative to unique(xs(:)) for numeric input. The values
    % are sorted, and every value larger than its predecessor is kept.
    % NaN values are not treated the way unique treats them (they may be
    % dropped), so the input is assumed to be NaN-free.
    % Empty input gives an empty 0x1 output.
    xs_sorted=sort(xs(:));

    if isempty(xs_sorted)
        % [true; diff(...)] below would index past the end of an empty array
        unq_xs=xs_sorted;
        return;
    end

    keep=[true;diff(xs_sorted)>0];
    unq_xs=xs_sorted(keep);