update loss_func
This commit is contained in:
parent
ac500c58ad
commit
b0ad80f0b9
@ -29,12 +29,13 @@
|
||||
|
||||
gctl::loss_func::loss_func()
|
||||
{
|
||||
uncer_type_ = 0;
|
||||
init_ = false;
|
||||
tnum_ = 0;
|
||||
ntype_ = L2;
|
||||
}
|
||||
|
||||
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type)
|
||||
{
|
||||
uncer_type_ = 0;
|
||||
init(tar, n_type);
|
||||
}
|
||||
|
||||
@ -42,59 +43,61 @@ gctl::loss_func::~loss_func(){}
|
||||
|
||||
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type)
|
||||
{
|
||||
tar_num_ = tar.size();
|
||||
tnum_ = tar.size();
|
||||
diff_.resize(tnum_);
|
||||
us_.resize(tnum_, 1.0);
|
||||
tars_ = tar;
|
||||
norm_type_ = n_type;
|
||||
ntype_ = n_type;
|
||||
init_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
void gctl::loss_func::set_uncertainty(double uncer)
|
||||
{
|
||||
uncer_type_ = 1;
|
||||
uncer_ = uncer;
|
||||
if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized.");
|
||||
us_.resize(tnum_, uncer);
|
||||
return;
|
||||
}
|
||||
|
||||
void gctl::loss_func::set_uncertainty(const array<double> &uncer)
|
||||
{
|
||||
uncer_type_ = 2;
|
||||
uncers_ = uncer;
|
||||
if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized.");
|
||||
if (uncer.size() != tnum_) throw std::runtime_error("[gctl::loss_func] Invalid array size.");
|
||||
us_ = uncer;
|
||||
return;
|
||||
}
|
||||
|
||||
double gctl::loss_func::get_loss()
|
||||
double gctl::loss_func::evaluate(const array<double> &x, array<double> &g)
|
||||
{
|
||||
double l = loss_;
|
||||
loss_ = 0.0;
|
||||
return l;
|
||||
if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized.");
|
||||
if (x.size() != tnum_) throw std::runtime_error("[gctl::loss_func] Invalid array size.");
|
||||
|
||||
for (size_t i = 0; i < tnum_; i++)
|
||||
{
|
||||
diff_[i] = (x[i] - tars_[i])/us_[i];
|
||||
}
|
||||
|
||||
double loss = 0.0;
|
||||
g.resize(tnum_);
|
||||
|
||||
if (ntype_ == L1)
|
||||
{
|
||||
for (size_t i = 0; i < tnum_; i++)
|
||||
{
|
||||
loss += fabs(diff_[i]);
|
||||
if (diff_[i] >= 0.0) g[i] = 1.0/(us_[i]*tnum_);
|
||||
else g[i] = -1.0/(us_[i]*tnum_);
|
||||
}
|
||||
}
|
||||
else if (ntype_ == L2)
|
||||
{
|
||||
for (size_t i = 0; i < tnum_; i++)
|
||||
{
|
||||
loss += diff_[i]*diff_[i];
|
||||
g[i] = 2.0*diff_[i]/(us_[i]*tnum_);
|
||||
}
|
||||
}
|
||||
else throw std::runtime_error("[gctl::loss_func] Invalid measurement type.");
|
||||
|
||||
return loss/tnum_;
|
||||
}
|
||||
|
||||
double gctl::loss_func::evaluate(double inp, int id)
|
||||
{
|
||||
double val = (inp - tars_[id]);
|
||||
if (uncer_type_ == 1) val /= uncer_;
|
||||
else if (uncer_type_ == 2) val /= uncers_[id];
|
||||
|
||||
if (norm_type_ == L1) val = fabs(val);
|
||||
if (norm_type_ == L2) val = val*val;
|
||||
|
||||
loss_ += val;
|
||||
return val/tar_num_;
|
||||
}
|
||||
|
||||
double gctl::loss_func::gradient(double inp, int id)
|
||||
{
|
||||
double c;
|
||||
if (uncer_type_ == 1) c = uncer_;
|
||||
else if (uncer_type_ == 2) c = uncers_[id];
|
||||
|
||||
double val = (inp - tars_[id]);
|
||||
if (norm_type_ == L1 && val >= 0) val = 1.0;
|
||||
if (norm_type_ == L1 && val < 0) val = -1.0;
|
||||
if (norm_type_ == L2) val = 2.0*val;
|
||||
|
||||
if (norm_type_ == L1 && uncer_type_ != 0) val /= c;
|
||||
else if (norm_type_ == L2 && uncer_type_ != 0) val /= (c*c);
|
||||
|
||||
return val/tar_num_;
|
||||
}
|
@ -43,18 +43,14 @@ namespace gctl
|
||||
void init(const array<double> &tar, norm_type_e n_type);
|
||||
void set_uncertainty(double uncer);
|
||||
void set_uncertainty(const array<double> &uncer);
|
||||
double get_loss();
|
||||
double evaluate(double inp, int id);
|
||||
double gradient(double inp, int id);
|
||||
double evaluate(const array<double> &x, array<double> &g);
|
||||
|
||||
private:
|
||||
//unsigned int counter_;
|
||||
unsigned int tar_num_;
|
||||
int uncer_type_;
|
||||
double uncer_, loss_;
|
||||
norm_type_e norm_type_;
|
||||
array<double> tars_;
|
||||
array<double> uncers_;
|
||||
bool init_;
|
||||
unsigned int tnum_;
|
||||
norm_type_e ntype_;
|
||||
array<double> tars_, diff_;
|
||||
array<double> us_;
|
||||
};
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user