odd-even classification with LDA classifier
For CoSMoMVPA's copyright information and license terms,
see the COPYING file distributed with CoSMoMVPA.
Contents
Define data
% Resolve the folder holding the tutorial data for subject s01 ('ak6' set)
config=cosmo_config();
data_path=fullfile(config.tutorial_data_path,'ak6','s01');

% File names of the per-run statistics volume and of the mask volume
data_fn=[data_path '/glm_T_stats_perrun.nii'];
mask_fn=[data_path '/vt_mask.nii'];

% Load the dataset, passing the mask file through the 'mask' option
ds=cosmo_fmri_dataset(data_fn,'mask',mask_fn);

% NOTE(review): presumably drops features that carry no information
% (e.g. constant or non-finite columns) - confirm against CoSMoMVPA docs
ds=cosmo_remove_useless_data(ds);
set sample attributes
% Category names; position in this list is the numeric target (1..6)
classes = {'monkey','lemur','mallard','warbler','ladybug','lunamoth'};

% Targets cycle through 1..6 ten times (60x1 column vector).
% NOTE(review): assumes the 60 samples are stored run by run, six
% categories per run, in the order of 'classes' - confirm with the data
ds.sa.targets = reshape(repmat((1:6)',1,10),[],1);

% Chunks (run indices): six consecutive samples share one chunk, 1..10
ds.sa.chunks = reshape(repmat(1:10,6,1),[],1);

% Human-readable label for every sample, aligned with the targets (60x1 cell)
ds.sa.labels = repmat(classes(:),10,1);
Part 1: bird classification; train on even runs, test on odd runs
% Parity of each sample's chunk value: 0 for even chunks, 1 for odd chunks
chunk_parity = mod(ds.sa.chunks,2);

% Boolean sample masks for the two halves of the data
even_msk = chunk_parity==0;
odd_msk = chunk_parity==1;

% Slice the full dataset into an even-chunk and an odd-chunk dataset
ds_even = cosmo_slice(ds,even_msk);
ds_odd = cosmo_slice(ds,odd_msk);
discriminate between mallards and warblers
% The two categories to tell apart in this part
categories={'mallard','warbler'};

% Select the samples whose label is one of the two categories,
% separately for the even and the odd half
msk_even_birds=cosmo_match(ds_even.sa.labels,categories);
msk_odd_birds=cosmo_match(ds_odd.sa.labels,categories);
ds_even_birds=cosmo_slice(ds_even,msk_even_birds);
ds_odd_birds=cosmo_slice(ds_odd,msk_odd_birds);

% Show both reduced datasets
fprintf('Even data:\n');
cosmo_disp(ds_even_birds);
fprintf('Odd data:\n');
cosmo_disp(ds_odd_birds);

% Training data come from the even half, test data from the odd half
train_samples=ds_even_birds.samples;
train_targets=ds_even_birds.sa.targets;
test_samples=ds_odd_birds.samples;
test_targets=ds_odd_birds.sa.targets;

% --- LDA classifier ---
test_pred=cosmo_classify_lda(train_samples,train_targets,test_samples);

% Print true and predicted targets side by side
fprintf('\ntarget predicted\n');
disp([test_targets test_pred]);

% Accuracy is the fraction of test samples predicted correctly
accuracy = nnz(test_pred==test_targets)/numel(test_targets);
fprintf('\nLDA birds even-odd: accuracy %.3f\n', accuracy);

% --- Naive Bayes classifier, same train/test data ---
test_pred_nb=cosmo_classify_naive_bayes(train_samples,train_targets,...
                                        test_samples);
accuracy = nnz(test_pred_nb==test_targets)/numel(test_targets);
fprintf('\nNaive Bayes birds even-odd: accuracy %.3f\n', accuracy);
Even data:
.a
.vol
.mat
[ -3 0 0 121
0 3 0 -114
0 0 3 -11.1
0 0 0 1 ]
.xform
'scanner_anat'
.dim
[ 80 80 43 ]
.fdim
.labels
{ 'i'
'j'
'k' }
.values
{ [ 1 2 3 ... 78 79 80 ]@1x80
[ 1 2 3 ... 78 79 80 ]@1x80
[ 1 2 3 ... 41 42 43 ]@1x43 }
.sa
.targets
[ 3
4
3
:
4
3
4 ]@10x1
.chunks
[ 2
2
4
:
8
10
10 ]@10x1
.labels
{ 'mallard'
'warbler'
'mallard'
:
'warbler'
'mallard'
'warbler' }@10x1
.samples
[ 2.9 1.72 2.76 ... 4.16 2.56 2.95
1.89 1.37 3.05 ... 4.38 4.46 3.86
1.26 2.41 2.87 ... 1.96 3.74 3.34
: : : : : :
2.09 1.53 1.98 ... 3.21 2.94 4.08
2.13 1.22 2.12 ... 2.3 3.14 1.99
1.5 1.6 2.07 ... 0.491 1.1 1.78 ]@10x384
.fa
.i
[ 32 33 34 ... 28 29 29 ]@1x384
.j
[ 24 24 24 ... 25 25 26 ]@1x384
.k
[ 3 3 3 ... 9 9 9 ]@1x384
Odd data:
.a
.vol
.mat
[ -3 0 0 121
0 3 0 -114
0 0 3 -11.1
0 0 0 1 ]
.xform
'scanner_anat'
.dim
[ 80 80 43 ]
.fdim
.labels
{ 'i'
'j'
'k' }
.values
{ [ 1 2 3 ... 78 79 80 ]@1x80
[ 1 2 3 ... 78 79 80 ]@1x80
[ 1 2 3 ... 41 42 43 ]@1x43 }
.sa
.targets
[ 3
4
3
:
4
3
4 ]@10x1
.chunks
[ 1
1
3
:
7
9
9 ]@10x1
.labels
{ 'mallard'
'warbler'
'mallard'
:
'warbler'
'mallard'
'warbler' }@10x1
.samples
[ 1.3 0.646 0.591 ... 1.51 1.75 3.08
3.08 3.05 4.08 ... 1.88 2.26 2.92
2.13 2.69 2.99 ... 2.03 2.87 1.74
: : : : : :
1.47 0.507 1.87 ... 2.59 3.09 3.51
2.24 2.37 3.27 ... 4.61 2.13 4.32
1.36 1.31 1.52 ... 2.95 2.61 1.18 ]@10x384
.fa
.i
[ 32 33 34 ... 28 29 29 ]@1x384
.j
[ 24 24 24 ... 25 25 26 ]@1x384
.k
[ 3 3 3 ... 9 9 9 ]@1x384
target predicted
3 4
4 4
3 3
4 4
3 3
4 4
3 3
4 4
3 3
4 4
LDA birds even-odd: accuracy 0.900
Naive Bayes birds even-odd: accuracy 0.700
Part 2: all categories; train/test on even/odd runs and vice versa
% --- Fold 1: train on the even half, test on the odd half ---
train_samples=ds_even.samples;
train_targets=ds_even.sa.targets;
test_samples=ds_odd.samples;
test_targets=ds_odd.sa.targets;

test_pred=cosmo_classify_lda(train_samples,train_targets,test_samples);

% fraction of correctly predicted test samples
accuracy = nnz(test_pred==test_targets)/numel(test_targets);
fprintf('\nLDA all categories even-odd: accuracy %.3f\n', accuracy);

% --- Fold 2: roles swapped, train on the odd half, test on the even half ---
train_samples=ds_odd.samples;
train_targets=ds_odd.sa.targets;
test_samples=ds_even.samples;
test_targets=ds_even.sa.targets;

% test_pred and test_targets of this second fold remain in the workspace
% after this section
test_pred=cosmo_classify_lda(train_samples,train_targets,test_samples);

accuracy = nnz(test_pred==test_targets)/numel(test_targets);
fprintf('\nLDA all categories odd-even: accuracy %.3f\n', accuracy);
LDA all categories even-odd: accuracy 0.767
LDA all categories odd-even: accuracy 0.733
Part 3: build confusion matrix
% Build a confusion matrix from test_targets and test_pred (both must
% already be in the workspace), verify it against
% cosmo_confusion_matrix, and show it as an image.
nclasses=numel(classes);

% Targets are used directly as row indices below, so the test set must
% contain exactly the targets 1..nclasses (derived from 'classes' rather
% than hard-coded, so the check stays valid if the class list changes)
assert(isequal(unique(test_targets),(1:nclasses)'));

% confusion_matrix(i,j) counts the test samples with true target i that
% were predicted as class j
confusion_matrix=zeros(nclasses);
for predicted=1:nclasses
    for target=1:nclasses
        match_mask=test_pred==predicted & test_targets==target;
        match_count=sum(match_mask);
        confusion_matrix(target,predicted)=match_count;
    end
end

% Cross-check the hand-built matrix against the toolbox implementation
confusion_matrix_alt=cosmo_confusion_matrix(test_targets,test_pred);
if ~isequal(confusion_matrix,confusion_matrix_alt)
    error('your confusion matrix does not match the expected output');
end

% Visualize: rows are targets, columns are predictions.
% NOTE(review): the fixed color range [0 5] presumably reflects five test
% samples per class - adjust if the train/test split changes
figure
imagesc(confusion_matrix,[0 5])
title('confusion matrix');
set(gca,'XTick',1:nclasses,'XTickLabel',classes,...
        'YTick',1:nclasses,'YTickLabel',classes);
ylabel('target');
xlabel('predicted');
colorbar