// Mirror of https://gitcode.com/gh_mirrors/ope/OpenFace.git
// (synced 2026-05-11 01:42:46 +00:00; 277 lines, 7.7 KiB, C++)
// Copyright (C) 2011 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_
|
|
#define DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_
|
|
|
|
#include "structural_assignment_trainer_abstract.h"
|
|
#include "../algs.h"
|
|
#include "../optimization.h"
|
|
#include "structural_svm_assignment_problem.h"
|
|
#include "num_nonnegative_weights.h"
|
|
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename feature_extractor
|
|
>
|
|
class structural_assignment_trainer
|
|
{
|
|
public:
|
|
typedef typename feature_extractor::lhs_element lhs_element;
|
|
typedef typename feature_extractor::rhs_element rhs_element;
|
|
typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
|
|
typedef std::vector<long> label_type;
|
|
typedef assignment_function<feature_extractor> trained_function_type;
|
|
|
|
structural_assignment_trainer (
|
|
)
|
|
{
|
|
set_defaults();
|
|
}
|
|
|
|
explicit structural_assignment_trainer (
|
|
const feature_extractor& fe_
|
|
) : fe(fe_)
|
|
{
|
|
set_defaults();
|
|
}
|
|
|
|
const feature_extractor& get_feature_extractor (
|
|
) const { return fe; }
|
|
|
|
void set_num_threads (
|
|
unsigned long num
|
|
)
|
|
{
|
|
num_threads = num;
|
|
}
|
|
|
|
unsigned long get_num_threads (
|
|
) const
|
|
{
|
|
return num_threads;
|
|
}
|
|
|
|
void set_epsilon (
|
|
double eps_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(eps_ > 0,
|
|
"\t void structural_assignment_trainer::set_epsilon()"
|
|
<< "\n\t eps_ must be greater than 0"
|
|
<< "\n\t eps_: " << eps_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
eps = eps_;
|
|
}
|
|
|
|
double get_epsilon (
|
|
) const { return eps; }
|
|
|
|
void set_max_cache_size (
|
|
unsigned long max_size
|
|
)
|
|
{
|
|
max_cache_size = max_size;
|
|
}
|
|
|
|
unsigned long get_max_cache_size (
|
|
) const
|
|
{
|
|
return max_cache_size;
|
|
}
|
|
|
|
void be_verbose (
|
|
)
|
|
{
|
|
verbose = true;
|
|
}
|
|
|
|
void be_quiet (
|
|
)
|
|
{
|
|
verbose = false;
|
|
}
|
|
|
|
void set_oca (
|
|
const oca& item
|
|
)
|
|
{
|
|
solver = item;
|
|
}
|
|
|
|
const oca get_oca (
|
|
) const
|
|
{
|
|
return solver;
|
|
}
|
|
|
|
void set_c (
|
|
double C_
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(C_ > 0,
|
|
"\t void structural_assignment_trainer::set_c()"
|
|
<< "\n\t C_ must be greater than 0"
|
|
<< "\n\t C_: " << C_
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
C = C_;
|
|
}
|
|
|
|
double get_c (
|
|
) const
|
|
{
|
|
return C;
|
|
}
|
|
|
|
bool forces_assignment(
|
|
) const { return force_assignment; }
|
|
|
|
void set_forces_assignment (
|
|
bool new_value
|
|
)
|
|
{
|
|
force_assignment = new_value;
|
|
}
|
|
|
|
void set_loss_per_false_association (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss > 0,
|
|
"\t void structural_assignment_trainer::set_loss_per_false_association(loss)"
|
|
<< "\n\t Invalid inputs were given to this function "
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
loss_per_false_association = loss;
|
|
}
|
|
|
|
double get_loss_per_false_association (
|
|
) const
|
|
{
|
|
return loss_per_false_association;
|
|
}
|
|
|
|
void set_loss_per_missed_association (
|
|
double loss
|
|
)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(loss > 0,
|
|
"\t void structural_assignment_trainer::set_loss_per_missed_association(loss)"
|
|
<< "\n\t Invalid inputs were given to this function "
|
|
<< "\n\t loss: " << loss
|
|
<< "\n\t this: " << this
|
|
);
|
|
|
|
loss_per_missed_association = loss;
|
|
}
|
|
|
|
double get_loss_per_missed_association (
|
|
) const
|
|
{
|
|
return loss_per_missed_association;
|
|
}
|
|
|
|
const assignment_function<feature_extractor> train (
|
|
const std::vector<sample_type>& samples,
|
|
const std::vector<label_type>& labels
|
|
) const
|
|
{
|
|
// make sure requires clause is not broken
|
|
#ifdef ENABLE_ASSERTS
|
|
if (force_assignment)
|
|
{
|
|
DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
|
|
"\t assignment_function structural_assignment_trainer::train()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
|
|
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
|
|
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
|
|
);
|
|
}
|
|
else
|
|
{
|
|
DLIB_ASSERT(is_assignment_problem(samples, labels),
|
|
"\t assignment_function structural_assignment_trainer::train()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
|
|
<< "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
|
|
);
|
|
}
|
|
#endif
|
|
|
|
|
|
|
|
structural_svm_assignment_problem<feature_extractor> prob(samples,labels, fe, force_assignment, num_threads,
|
|
loss_per_false_association, loss_per_missed_association);
|
|
|
|
if (verbose)
|
|
prob.be_verbose();
|
|
|
|
prob.set_c(C);
|
|
prob.set_epsilon(eps);
|
|
prob.set_max_cache_size(max_cache_size);
|
|
|
|
matrix<double,0,1> weights;
|
|
|
|
// Take the min here because we want to prevent the user from accidentally
|
|
// forcing the bias term to be non-negative.
|
|
const unsigned long num_nonneg = std::min(fe.num_features(),num_nonnegative_weights(fe));
|
|
solver(prob, weights, num_nonneg);
|
|
|
|
const double bias = weights(weights.size()-1);
|
|
return assignment_function<feature_extractor>(colm(weights,0,weights.size()-1), bias,fe,force_assignment);
|
|
|
|
}
|
|
|
|
|
|
private:
|
|
|
|
bool force_assignment;
|
|
double C;
|
|
oca solver;
|
|
double eps;
|
|
bool verbose;
|
|
unsigned long num_threads;
|
|
unsigned long max_cache_size;
|
|
double loss_per_false_association;
|
|
double loss_per_missed_association;
|
|
|
|
void set_defaults ()
|
|
{
|
|
force_assignment = false;
|
|
C = 100;
|
|
verbose = false;
|
|
eps = 0.01;
|
|
num_threads = 2;
|
|
max_cache_size = 5;
|
|
loss_per_false_association = 1;
|
|
loss_per_missed_association = 1;
|
|
}
|
|
|
|
feature_extractor fe;
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_
|
|
|
|
|
|
|
|
|