function predicted = cosmo_classify_lda(samples_train, targets_train, samples_test, opt)
    % linear discriminant analysis classifier - without prior
    %
    % predicted=cosmo_classify_lda(samples_train, targets_train, samples_test[,opt])
    %
    % Inputs:
    %   samples_train       PxR training data for P samples and R features
    %   targets_train       Px1 training data classes
    %   samples_test        QxR test data
    %   opt                 Optional struct with optional fields:
    %    .regularization    Used to regularize covariance matrix. Default .01
    %    .max_feature_count Used to set the maximum number of features,
    %                       defaults to 5000. If R is larger than this
    %                       value, an error is raised. This is a (conservative)
    %                       safety limit to avoid huge memory consumption
    %                       for large values of R, because training the
    %                       classifier typically requires in the order of 8*R^2
    %                       bytes of memory
    %
    % Output:
    %   predicted          Qx1 predicted data classes for samples_test
    %
    % Example:
    %     ds=cosmo_synthetic_dataset('ntargets',5,'nchunks',15);
    %     test_chunk=1;
    %     te=cosmo_slice(ds,ds.sa.chunks==test_chunk);
    %     tr=cosmo_slice(ds,ds.sa.chunks~=test_chunk);
    %     pred=cosmo_classify_lda(tr.samples,tr.sa.targets,te.samples,struct);
    %     % show targets and predicted labels
    %     disp([te.sa.targets pred])
    %     %||       1     1
    %     %||       2     2
    %     %||       3     3
    %     %||       4     4
    %     %||       5     5
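    %
    %     % (illustrative) stronger regularization can be requested through
    %     % the 'regularization' option; larger values shrink the covariance
    %     % estimate more towards a scaled identity matrix
    %     pred_reg=cosmo_classify_lda(tr.samples,tr.sa.targets,...
    %                             te.samples,struct('regularization',.1));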
    %
    % Notes:
    % - this classifier does not support a prior; that is, it assumes that
    %   all classes have the same number of samples in the training set. If
    %   that is not the case, an error is thrown.
    % - a safety limit of opt.max_feature_count is imposed because a large
    %   number of features can crash Matlab / Octave, or make it very slow.
    %
    % (Contributions by Joern Diedrichsen, Tobias Wiestler, Nikolaas Oosterhof)
    % #   For CoSMoMVPA's copyright information and license terms,   #
    % #   see the COPYING file distributed with CoSMoMVPA.           #

    if nargin < 4 || isempty(opt)
        opt = struct();
    end
    if ~isfield(opt, 'regularization')
        opt.regularization = .01;
    end
    if ~isfield(opt, 'max_feature_count')
        opt.max_feature_count = 5000;
    end

    % support repeated testing on different data after training each time on
    % the same data. This is achieved by caching the training data, the
    % options, and the associated trained model
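    % (for example, when the same training set is evaluated against many
    % different test sets, training only occurs on the first call)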
    persistent cached_targets_train
    persistent cached_samples_train
    persistent cached_opt
    persistent cached_model

    if isequal(cached_targets_train, targets_train) && ...
            isequal(cached_opt, opt) && ...
            isequal(cached_samples_train, samples_train)
        % use cache
        model = cached_model;
    else
        % train classifier
        model = train(samples_train, targets_train, opt);

        % store model
        cached_targets_train = targets_train;
        cached_samples_train = samples_train;
        cached_opt = opt;
        cached_model = model;
    end

    % test classifier
    predicted = test(model, samples_test);

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % helper functions
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function model = train(samples_train, targets_train, opt)
    % train LDA classifier

    [ntrain, nfeatures] = size(samples_train);
    ntrain_ = numel(targets_train);

    if ntrain_ ~= ntrain
        error(['size mismatch: samples_train has %d rows, '...
               'targets_train has %d values'], ntrain, ntrain_);
    end

    if nfeatures > opt.max_feature_count
        % compute the memory size per number; typically this is 8 bytes,
        % unless single precision is used.
        w = whos('samples_train');
        size_per_number = w.bytes / numel(samples_train);

        % the covariance matrix is nfeatures x nfeatures
        mem_required = nfeatures^2 * size_per_number;
        mem_required_mb = round(mem_required / 1e6); % in megabytes

        error(['A large number of features (%d) was found, '...
               'exceeding the safety limit max_feature_count=%d '...
               'for the %s function.\n'...
               'This limit is imposed because computing the '...
               'covariance matrix will require in the order of '...
               '%.0f MB of memory, and inverting it may take '...
               'a long time.\n'...
               'The safety limit can be changed by setting '...
               'the ''max_feature_count'' option to another '...
               'value, but large values ***may freeze or crash '...
               'the machine***.'], ...
              nfeatures, opt.max_feature_count, ...
              mfilename(), mem_required_mb);
    end

    classes = fast_vector_unique(targets_train);
    nclasses = numel(classes);

    class_mean = zeros(nclasses, nfeatures);   % class means
    class_cov = zeros(nfeatures);              % within-class variability

    % compute mean and (co)variance
    for k = 1:nclasses
        % select data in this class
        msk = targets_train == classes(k);

        % number of samples in k-th class
        n = sum(msk);

        if k == 1
            if n < 2
                error(['Need at least two samples per class '...
                       'in training']);
            end
            nfirst = n; % keep track of number of samples
        elseif nfirst ~= n
            error(['Different number of samples per class (%d and %d) '...
                   '- this is not supported. When using partitions, '...
                   'consider using cosmo_balance_partitions'], ...
                  n, nfirst);
        end

        class_samples = samples_train(msk, :);

        class_mean(k, :) = sum(class_samples, 1) / n; % class mean
        res = bsxfun(@minus, class_samples, class_mean(k, :)); % residuals
        class_cov = class_cov + res' * res; % estimate common covariance matrix
    end
    % normalize by the total number of training samples to obtain the
    % pooled within-class covariance, then apply regularization
    regularization = opt.regularization;
    class_cov = class_cov / ntrain;
    reg = eye(nfeatures) * trace(class_cov) / max(1, nfeatures);
    class_cov_reg = class_cov + reg * regularization;
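    % this implements shrinkage towards a scaled identity matrix:
    %     class_cov_reg = class_cov + lambda * (trace(class_cov)/R) * I
    % where lambda=opt.regularization and R=nfeatures, so the strength of
    % the regularization scales with the average variance across features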

    % linear discriminant
    class_weight = class_mean / class_cov_reg;
    class_offset = sum(class_weight .* class_mean, 2);
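    % with equal class priors, the discriminant for class k is
    %     g_k(x) = m_k * inv(class_cov_reg) * x' ...
    %                 - .5 * m_k * inv(class_cov_reg) * m_k'
    % where m_k is the k-th class mean; class_weight stores the row vectors
    % m_k * inv(class_cov_reg), and class_offset the scalars
    % m_k * inv(class_cov_reg) * m_k'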

    model = struct();
    model.class_offset = class_offset;
    model.class_weight = class_weight;
    model.classes = classes;

function predicted = test(model, samples_test)
    % test LDA classifier

    class_offset = model.class_offset;
    class_weight = model.class_weight;
    classes = model.classes;

    if size(samples_test, 2) ~= size(class_weight, 2)
        error('test set has %d features, train set has %d', ...
              size(samples_test, 2), size(class_weight, 2));
    end

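    % evaluate g_k(x) = w_k * x' - .5 * offset_k for each class k (rows)
    % and test sample x (columns); each test sample is then assigned the
    % class with the highest discriminant value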
    class_proj = bsxfun(@plus, -.5 * class_offset, class_weight * samples_test');

    [unused, class_idxs] = max(class_proj);
    predicted = classes(class_idxs);

function unq_xs = fast_vector_unique(xs)
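    % return the unique values of numeric vector xs in ascending order;
    % sorting once and keeping elements that differ from their predecessor
    % avoids the overhead of the built-in unique function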
    xs_sorted = sort(xs(:));
    idxs = [true; diff(xs_sorted) > 0];
    unq_xs = xs_sorted(idxs);