mirror of
https://gitcode.com/gh_mirrors/ope/OpenFace.git
synced 2026-05-14 11:17:53 +00:00
More work on AU - WIP
This commit is contained in:
@@ -8,10 +8,24 @@ function [result, prediction] = svr_test_linear(test_labels, test_samples, model
|
||||
prediction(prediction>5)=5;
|
||||
|
||||
% using CCC as the evaluation metric
|
||||
result = corr(test_labels, prediction);
|
||||
[ ~, ~, ~, ccc, ~, ~ ] = evaluate_regression_results( prediction, test_labels );
|
||||
|
||||
result = ccc;
|
||||
% using the average of CCC errors if different datasets are used
|
||||
if(~isfield(model, 'eval_ids'))
|
||||
result = corr(test_labels, prediction);
|
||||
[ ~, ~, ~, ccc, ~, ~ ] = evaluate_regression_results( prediction, test_labels );
|
||||
result = ccc;
|
||||
else
|
||||
eval_ids = unique(model.eval_ids)';
|
||||
ccc = 0;
|
||||
fprintf('CCC: ');
|
||||
for i=eval_ids
|
||||
[ ~, ~, ~, ccc_curr, ~, ~ ] = evaluate_regression_results( prediction(model.eval_ids == i), test_labels(model.eval_ids == i) );
|
||||
ccc = ccc + ccc_curr;
|
||||
fprintf('%.3f ', ccc_curr);
|
||||
end
|
||||
ccc = ccc / numel(eval_ids);
|
||||
fprintf('mean : %.3f\n', ccc);
|
||||
result = ccc;
|
||||
end
|
||||
|
||||
if(isnan(result))
|
||||
result = 0;
|
||||
|
||||
@@ -29,10 +29,23 @@ function [result, prediction] = svr_test_linear_shift(test_labels, test_samples,
|
||||
|
||||
% using the average of RMS errors
|
||||
% result = mean(sqrt(mean((prediction - test_labels).^2)));
|
||||
result = corr(test_labels, prediction);
|
||||
[ ~, ~, ~, ccc, ~, ~ ] = evaluate_regression_results( prediction, test_labels );
|
||||
|
||||
result = ccc;
|
||||
if(~isfield(model, 'eval_ids'))
|
||||
result = corr(test_labels, prediction);
|
||||
[ ~, ~, ~, ccc, ~, ~ ] = evaluate_regression_results( prediction, test_labels );
|
||||
result = ccc;
|
||||
else
|
||||
eval_ids = unique(model.eval_ids)';
|
||||
ccc = 0;
|
||||
fprintf('CCC: ');
|
||||
for i=eval_ids
|
||||
[ ~, ~, ~, ccc_curr, ~, ~ ] = evaluate_regression_results( prediction(model.eval_ids == i), test_labels(model.eval_ids == i) );
|
||||
ccc = ccc + ccc_curr;
|
||||
fprintf('%.3f ', ccc_curr);
|
||||
end
|
||||
ccc = ccc / numel(eval_ids);
|
||||
fprintf('mean : %.3f\n', ccc);
|
||||
result = ccc;
|
||||
end
|
||||
|
||||
if(isnan(result))
|
||||
result = 0;
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
function [result, prediction] = svr_test_linear_shift_fancy(test_labels, test_samples, model)
|
||||
|
||||
prediction = test_samples * model.w(1:end-1)' + model.w(end);
|
||||
% prediction = predict(test_labels, test_samples, model);
|
||||
|
||||
prediction(~model.success) = 0;
|
||||
|
||||
if(model.cutoff >= 0)
|
||||
% perform shifting here per person
|
||||
users = unique(model.vid_ids);
|
||||
|
||||
for i=1:numel(users)
|
||||
|
||||
preds_user = prediction(strcmp(model.vid_ids, users(i)));
|
||||
sorted = sort(preds_user);
|
||||
|
||||
% alternative, move to histograms and pick the highest one
|
||||
|
||||
shift = sorted(round(end*model.cutoff)+1);
|
||||
|
||||
prediction(strcmp(model.vid_ids, users(i))) = preds_user - shift;
|
||||
|
||||
end
|
||||
end
|
||||
|
||||
% Cap the prediction as well
|
||||
prediction(prediction<0)=0;
|
||||
prediction(prediction>5)=5;
|
||||
|
||||
% using the average of RMS errors
|
||||
% result = mean(sqrt(mean((prediction - test_labels).^2)));
|
||||
result = corr(test_labels, prediction);
|
||||
[ ~, ~, ~, ccc, ~, ~ ] = evaluate_regression_results( prediction, test_labels );
|
||||
|
||||
result = ccc;
|
||||
|
||||
if(isnan(result))
|
||||
result = 0;
|
||||
end
|
||||
|
||||
end
|
||||
@@ -5,4 +5,9 @@ function [model] = svr_train_linear(train_labels, train_samples, hyper)
|
||||
comm = sprintf('-s 11 -B 1 -p %.10f -c %.10f -q', hyper.p, hyper.c);
|
||||
model = train(train_labels, train_samples, comm);
|
||||
model.success = hyper.success;
|
||||
|
||||
if(isfield(hyper, 'eval_ids'))
|
||||
model.eval_ids = hyper.eval_ids;
|
||||
end
|
||||
|
||||
end
|
||||
@@ -1,4 +1,4 @@
|
||||
function [model] = svr_train_linear_shift_fancy(train_labels, train_samples, hyper)
|
||||
function [model] = svr_train_linear_shift(train_labels, train_samples, hyper)
|
||||
|
||||
% Change to your downloaded location
|
||||
addpath('C:\liblinear\matlab')
|
||||
@@ -54,4 +54,9 @@ function [model] = svr_train_linear_shift_fancy(train_labels, train_samples, hyp
|
||||
model.cutoff = cutoffs(best_id);
|
||||
model.vid_ids = hyper.vid_ids;
|
||||
model.success = hyper.success;
|
||||
|
||||
if(isfield(hyper, 'eval_ids'))
|
||||
model.eval_ids = hyper.eval_ids;
|
||||
end
|
||||
|
||||
end
|
||||
Reference in New Issue
Block a user