tmp update
This commit is contained in:
parent
094bbfa70d
commit
892c460b44
@ -31,6 +31,7 @@ gctl::loss_func::loss_func()
|
||||
{
|
||||
init_ = false;
|
||||
eps_ = 1e-16;
|
||||
loss_ = 0.0;
|
||||
tnum_ = 0;
|
||||
ntype_ = L2;
|
||||
}
|
||||
@ -117,3 +118,33 @@ double gctl::loss_func::evaluate(const array<double> &x, array<double> &g)
|
||||
|
||||
return loss/tnum_;
|
||||
}
|
||||
|
||||
double gctl::loss_func::evaluate(double inp, int id)
|
||||
{
|
||||
double val = (inp - tars_[id])/us_[id];
|
||||
|
||||
if (ntype_ == L1) val = fabs(val);
|
||||
if (ntype_ == L2) val = val*val;
|
||||
|
||||
loss_ += val;
|
||||
|
||||
return val/tnum_;
|
||||
}
|
||||
|
||||
double gctl::loss_func::get_loss()
|
||||
{
|
||||
double l = loss_;
|
||||
loss_ = 0.0;
|
||||
|
||||
return l;
|
||||
}
|
||||
|
||||
double gctl::loss_func::gradient(double inp, int id)
|
||||
{
|
||||
double val;
|
||||
if (ntype_ == L1 && val >= 0) val = 1.0/(us_[id]*tnum_);
|
||||
if (ntype_ == L1 && val < 0) val = -1.0/(us_[id]*tnum_);
|
||||
if (ntype_ == L2) val = 2.0*(inp - tars_[id])/(us_[id]*us_[id]*tnum_);
|
||||
|
||||
return val;
|
||||
}
|
@ -43,11 +43,16 @@ namespace gctl
|
||||
void init(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
|
||||
void set_uncertainty(double uncer);
|
||||
void set_uncertainty(const array<double> &uncer);
|
||||
|
||||
double evaluate(double inp, int id);
|
||||
double evaluate(const array<double> &x, array<double> &g);
|
||||
|
||||
double get_loss();
|
||||
double gradient(double inp, int id);
|
||||
|
||||
private:
|
||||
bool init_;
|
||||
double eps_, p_;
|
||||
double loss_, eps_, p_;
|
||||
unsigned int tnum_;
|
||||
norm_type_e ntype_;
|
||||
array<double> tars_, diff_;
|
||||
|
Loading…
Reference in New Issue
Block a user