tmp update
This commit is contained in:
parent
b0ad80f0b9
commit
9aedd1b8e8
@ -30,25 +30,33 @@
|
|||||||
gctl::loss_func::loss_func()
|
gctl::loss_func::loss_func()
|
||||||
{
|
{
|
||||||
init_ = false;
|
init_ = false;
|
||||||
|
eps_ = 1e-8;
|
||||||
tnum_ = 0;
|
tnum_ = 0;
|
||||||
ntype_ = L2;
|
ntype_ = L2;
|
||||||
}
|
}
|
||||||
|
|
||||||
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type)
|
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type, double p, double eps)
|
||||||
{
|
{
|
||||||
init(tar, n_type);
|
init(tar, n_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
gctl::loss_func::~loss_func(){}
|
gctl::loss_func::~loss_func(){}
|
||||||
|
|
||||||
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type)
|
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type, double p, double eps)
|
||||||
{
|
{
|
||||||
|
if (p < 1) throw std::runtime_error("[gctl::loss_func] Invalid power number.");
|
||||||
|
if (eps <= 0) throw std::runtime_error("[gctl::loss_func] Invalid epsilon value.");
|
||||||
|
|
||||||
|
init_ = true;
|
||||||
|
|
||||||
tnum_ = tar.size();
|
tnum_ = tar.size();
|
||||||
diff_.resize(tnum_);
|
diff_.resize(tnum_);
|
||||||
us_.resize(tnum_, 1.0);
|
us_.resize(tnum_, 1.0);
|
||||||
tars_ = tar;
|
tars_ = tar;
|
||||||
|
|
||||||
ntype_ = n_type;
|
ntype_ = n_type;
|
||||||
init_ = true;
|
eps_ = eps;
|
||||||
|
p_ = p;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,6 +105,14 @@ double gctl::loss_func::evaluate(const array<double> &x, array<double> &g)
|
|||||||
g[i] = 2.0*diff_[i]/(us_[i]*tnum_);
|
g[i] = 2.0*diff_[i]/(us_[i]*tnum_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else if (ntype_ == Lp)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < tnum_; i++)
|
||||||
|
{
|
||||||
|
loss += pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_);
|
||||||
|
g[i] = p_*pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_ - 1)*diff_[i]/(us_[i]*tnum_);
|
||||||
|
}
|
||||||
|
}
|
||||||
else throw std::runtime_error("[gctl::loss_func] Invalid measurement type.");
|
else throw std::runtime_error("[gctl::loss_func] Invalid measurement type.");
|
||||||
|
|
||||||
return loss/tnum_;
|
return loss/tnum_;
|
||||||
|
@ -37,16 +37,17 @@ namespace gctl
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
loss_func();
|
loss_func();
|
||||||
loss_func(const array<double> &tar, norm_type_e n_type);
|
loss_func(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
|
||||||
virtual ~loss_func();
|
virtual ~loss_func();
|
||||||
|
|
||||||
void init(const array<double> &tar, norm_type_e n_type);
|
void init(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
|
||||||
void set_uncertainty(double uncer);
|
void set_uncertainty(double uncer);
|
||||||
void set_uncertainty(const array<double> &uncer);
|
void set_uncertainty(const array<double> &uncer);
|
||||||
double evaluate(const array<double> &x, array<double> &g);
|
double evaluate(const array<double> &x, array<double> &g);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool init_;
|
bool init_;
|
||||||
|
double eps_, p_;
|
||||||
unsigned int tnum_;
|
unsigned int tnum_;
|
||||||
norm_type_e ntype_;
|
norm_type_e ntype_;
|
||||||
array<double> tars_, diff_;
|
array<double> tars_, diff_;
|
||||||
|
Loading…
Reference in New Issue
Block a user