update loss_func

This commit is contained in:
张壹 2024-10-07 08:37:31 +08:00
parent ac500c58ad
commit b0ad80f0b9
2 changed files with 51 additions and 52 deletions

View File

@ -29,12 +29,13 @@
gctl::loss_func::loss_func()
{
uncer_type_ = 0;
init_ = false;
tnum_ = 0;
ntype_ = L2;
}
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type)
{
uncer_type_ = 0;
init(tar, n_type);
}
@ -42,59 +43,61 @@ gctl::loss_func::~loss_func(){}
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type)
{
tar_num_ = tar.size();
tnum_ = tar.size();
diff_.resize(tnum_);
us_.resize(tnum_, 1.0);
tars_ = tar;
norm_type_ = n_type;
ntype_ = n_type;
init_ = true;
return;
}
void gctl::loss_func::set_uncertainty(double uncer)
{
uncer_type_ = 1;
uncer_ = uncer;
if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized.");
us_.resize(tnum_, uncer);
return;
}
void gctl::loss_func::set_uncertainty(const array<double> &uncer)
{
uncer_type_ = 2;
uncers_ = uncer;
if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized.");
if (uncer.size() != tnum_) throw std::runtime_error("[gctl::loss_func] Invalid array size.");
us_ = uncer;
return;
}
double gctl::loss_func::get_loss()
double gctl::loss_func::evaluate(const array<double> &x, array<double> &g)
{
double l = loss_;
loss_ = 0.0;
return l;
if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized.");
if (x.size() != tnum_) throw std::runtime_error("[gctl::loss_func] Invalid array size.");
for (size_t i = 0; i < tnum_; i++)
{
diff_[i] = (x[i] - tars_[i])/us_[i];
}
double gctl::loss_func::evaluate(double inp, int id)
double loss = 0.0;
g.resize(tnum_);
if (ntype_ == L1)
{
double val = (inp - tars_[id]);
if (uncer_type_ == 1) val /= uncer_;
else if (uncer_type_ == 2) val /= uncers_[id];
if (norm_type_ == L1) val = fabs(val);
if (norm_type_ == L2) val = val*val;
loss_ += val;
return val/tar_num_;
}
double gctl::loss_func::gradient(double inp, int id)
for (size_t i = 0; i < tnum_; i++)
{
double c;
if (uncer_type_ == 1) c = uncer_;
else if (uncer_type_ == 2) c = uncers_[id];
double val = (inp - tars_[id]);
if (norm_type_ == L1 && val >= 0) val = 1.0;
if (norm_type_ == L1 && val < 0) val = -1.0;
if (norm_type_ == L2) val = 2.0*val;
if (norm_type_ == L1 && uncer_type_ != 0) val /= c;
else if (norm_type_ == L2 && uncer_type_ != 0) val /= (c*c);
return val/tar_num_;
loss += fabs(diff_[i]);
if (diff_[i] >= 0.0) g[i] = 1.0/(us_[i]*tnum_);
else g[i] = -1.0/(us_[i]*tnum_);
}
}
else if (ntype_ == L2)
{
for (size_t i = 0; i < tnum_; i++)
{
loss += diff_[i]*diff_[i];
g[i] = 2.0*diff_[i]/(us_[i]*tnum_);
}
}
else throw std::runtime_error("[gctl::loss_func] Invalid measurement type.");
return loss/tnum_;
}

View File

@ -43,18 +43,14 @@ namespace gctl
void init(const array<double> &tar, norm_type_e n_type);
void set_uncertainty(double uncer);
void set_uncertainty(const array<double> &uncer);
double get_loss();
double evaluate(double inp, int id);
double gradient(double inp, int id);
double evaluate(const array<double> &x, array<double> &g);
private:
//unsigned int counter_;
unsigned int tar_num_;
int uncer_type_;
double uncer_, loss_;
norm_type_e norm_type_;
array<double> tars_;
array<double> uncers_;
bool init_;
unsigned int tnum_;
norm_type_e ntype_;
array<double> tars_, diff_;
array<double> us_;
};
}