tmp update

This commit is contained in:
张壹 2024-10-08 11:25:52 +08:00
parent 094bbfa70d
commit 892c460b44
2 changed files with 37 additions and 1 deletions

View File

@ -31,6 +31,7 @@ gctl::loss_func::loss_func()
{ {
init_ = false; init_ = false;
eps_ = 1e-16; eps_ = 1e-16;
loss_ = 0.0;
tnum_ = 0; tnum_ = 0;
ntype_ = L2; ntype_ = L2;
} }
@ -117,3 +118,33 @@ double gctl::loss_func::evaluate(const array<double> &x, array<double> &g)
return loss/tnum_; return loss/tnum_;
} }
double gctl::loss_func::evaluate(double inp, int id)
{
double val = (inp - tars_[id])/us_[id];
if (ntype_ == L1) val = fabs(val);
if (ntype_ == L2) val = val*val;
loss_ += val;
return val/tnum_;
}
double gctl::loss_func::get_loss()
{
double l = loss_;
loss_ = 0.0;
return l;
}
double gctl::loss_func::gradient(double inp, int id)
{
double val;
if (ntype_ == L1 && val >= 0) val = 1.0/(us_[id]*tnum_);
if (ntype_ == L1 && val < 0) val = -1.0/(us_[id]*tnum_);
if (ntype_ == L2) val = 2.0*(inp - tars_[id])/(us_[id]*us_[id]*tnum_);
return val;
}

View File

@ -43,11 +43,16 @@ namespace gctl
void init(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16); void init(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
void set_uncertainty(double uncer); void set_uncertainty(double uncer);
void set_uncertainty(const array<double> &uncer); void set_uncertainty(const array<double> &uncer);
double evaluate(double inp, int id);
double evaluate(const array<double> &x, array<double> &g); double evaluate(const array<double> &x, array<double> &g);
double get_loss();
double gradient(double inp, int id);
private: private:
bool init_; bool init_;
double eps_, p_; double loss_, eps_, p_;
unsigned int tnum_; unsigned int tnum_;
norm_type_e ntype_; norm_type_e ntype_;
array<double> tars_, diff_; array<double> tars_, diff_;