update loss_func
This commit is contained in:
parent
ac500c58ad
commit
b0ad80f0b9
@ -29,12 +29,13 @@
|
|||||||
|
|
||||||
gctl::loss_func::loss_func()
|
gctl::loss_func::loss_func()
|
||||||
{
|
{
|
||||||
uncer_type_ = 0;
|
init_ = false;
|
||||||
|
tnum_ = 0;
|
||||||
|
ntype_ = L2;
|
||||||
}
|
}
|
||||||
|
|
||||||
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type)
|
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type)
|
||||||
{
|
{
|
||||||
uncer_type_ = 0;
|
|
||||||
init(tar, n_type);
|
init(tar, n_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,59 +43,61 @@ gctl::loss_func::~loss_func(){}
|
|||||||
|
|
||||||
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type)
|
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type)
|
||||||
{
|
{
|
||||||
tar_num_ = tar.size();
|
tnum_ = tar.size();
|
||||||
|
diff_.resize(tnum_);
|
||||||
|
us_.resize(tnum_, 1.0);
|
||||||
tars_ = tar;
|
tars_ = tar;
|
||||||
norm_type_ = n_type;
|
ntype_ = n_type;
|
||||||
|
init_ = true;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void gctl::loss_func::set_uncertainty(double uncer)
|
void gctl::loss_func::set_uncertainty(double uncer)
|
||||||
{
|
{
|
||||||
uncer_type_ = 1;
|
if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized.");
|
||||||
uncer_ = uncer;
|
us_.resize(tnum_, uncer);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void gctl::loss_func::set_uncertainty(const array<double> &uncer)
|
void gctl::loss_func::set_uncertainty(const array<double> &uncer)
|
||||||
{
|
{
|
||||||
uncer_type_ = 2;
|
if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized.");
|
||||||
uncers_ = uncer;
|
if (uncer.size() != tnum_) throw std::runtime_error("[gctl::loss_func] Invalid array size.");
|
||||||
|
us_ = uncer;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
double gctl::loss_func::get_loss()
|
double gctl::loss_func::evaluate(const array<double> &x, array<double> &g)
|
||||||
{
|
{
|
||||||
double l = loss_;
|
if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized.");
|
||||||
loss_ = 0.0;
|
if (x.size() != tnum_) throw std::runtime_error("[gctl::loss_func] Invalid array size.");
|
||||||
return l;
|
|
||||||
|
for (size_t i = 0; i < tnum_; i++)
|
||||||
|
{
|
||||||
|
diff_[i] = (x[i] - tars_[i])/us_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
double loss = 0.0;
|
||||||
|
g.resize(tnum_);
|
||||||
|
|
||||||
|
if (ntype_ == L1)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < tnum_; i++)
|
||||||
|
{
|
||||||
|
loss += fabs(diff_[i]);
|
||||||
|
if (diff_[i] >= 0.0) g[i] = 1.0/(us_[i]*tnum_);
|
||||||
|
else g[i] = -1.0/(us_[i]*tnum_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (ntype_ == L2)
|
||||||
|
{
|
||||||
|
for (size_t i = 0; i < tnum_; i++)
|
||||||
|
{
|
||||||
|
loss += diff_[i]*diff_[i];
|
||||||
|
g[i] = 2.0*diff_[i]/(us_[i]*tnum_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else throw std::runtime_error("[gctl::loss_func] Invalid measurement type.");
|
||||||
|
|
||||||
|
return loss/tnum_;
|
||||||
}
|
}
|
||||||
|
|
||||||
double gctl::loss_func::evaluate(double inp, int id)
|
|
||||||
{
|
|
||||||
double val = (inp - tars_[id]);
|
|
||||||
if (uncer_type_ == 1) val /= uncer_;
|
|
||||||
else if (uncer_type_ == 2) val /= uncers_[id];
|
|
||||||
|
|
||||||
if (norm_type_ == L1) val = fabs(val);
|
|
||||||
if (norm_type_ == L2) val = val*val;
|
|
||||||
|
|
||||||
loss_ += val;
|
|
||||||
return val/tar_num_;
|
|
||||||
}
|
|
||||||
|
|
||||||
double gctl::loss_func::gradient(double inp, int id)
|
|
||||||
{
|
|
||||||
double c;
|
|
||||||
if (uncer_type_ == 1) c = uncer_;
|
|
||||||
else if (uncer_type_ == 2) c = uncers_[id];
|
|
||||||
|
|
||||||
double val = (inp - tars_[id]);
|
|
||||||
if (norm_type_ == L1 && val >= 0) val = 1.0;
|
|
||||||
if (norm_type_ == L1 && val < 0) val = -1.0;
|
|
||||||
if (norm_type_ == L2) val = 2.0*val;
|
|
||||||
|
|
||||||
if (norm_type_ == L1 && uncer_type_ != 0) val /= c;
|
|
||||||
else if (norm_type_ == L2 && uncer_type_ != 0) val /= (c*c);
|
|
||||||
|
|
||||||
return val/tar_num_;
|
|
||||||
}
|
|
@ -43,18 +43,14 @@ namespace gctl
|
|||||||
void init(const array<double> &tar, norm_type_e n_type);
|
void init(const array<double> &tar, norm_type_e n_type);
|
||||||
void set_uncertainty(double uncer);
|
void set_uncertainty(double uncer);
|
||||||
void set_uncertainty(const array<double> &uncer);
|
void set_uncertainty(const array<double> &uncer);
|
||||||
double get_loss();
|
double evaluate(const array<double> &x, array<double> &g);
|
||||||
double evaluate(double inp, int id);
|
|
||||||
double gradient(double inp, int id);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//unsigned int counter_;
|
bool init_;
|
||||||
unsigned int tar_num_;
|
unsigned int tnum_;
|
||||||
int uncer_type_;
|
norm_type_e ntype_;
|
||||||
double uncer_, loss_;
|
array<double> tars_, diff_;
|
||||||
norm_type_e norm_type_;
|
array<double> us_;
|
||||||
array<double> tars_;
|
|
||||||
array<double> uncers_;
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user