diff --git a/lib/optimization/loss_func.cpp b/lib/optimization/loss_func.cpp index 26464cf..ce890a5 100644 --- a/lib/optimization/loss_func.cpp +++ b/lib/optimization/loss_func.cpp @@ -29,12 +29,13 @@ gctl::loss_func::loss_func() { - uncer_type_ = 0; + init_ = false; + tnum_ = 0; + ntype_ = L2; } gctl::loss_func::loss_func(const array &tar, norm_type_e n_type) { - uncer_type_ = 0; init(tar, n_type); } @@ -42,59 +43,61 @@ gctl::loss_func::~loss_func(){} void gctl::loss_func::init(const array &tar, norm_type_e n_type) { - tar_num_ = tar.size(); + tnum_ = tar.size(); + diff_.resize(tnum_); + us_.resize(tnum_, 1.0); tars_ = tar; - norm_type_ = n_type; + ntype_ = n_type; + init_ = true; return; } void gctl::loss_func::set_uncertainty(double uncer) { - uncer_type_ = 1; - uncer_ = uncer; + if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized."); + us_.resize(tnum_, uncer); return; } void gctl::loss_func::set_uncertainty(const array &uncer) { - uncer_type_ = 2; - uncers_ = uncer; + if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized."); + if (uncer.size() != tnum_) throw std::runtime_error("[gctl::loss_func] Invalid array size."); + us_ = uncer; return; } -double gctl::loss_func::get_loss() +double gctl::loss_func::evaluate(const array &x, array &g) { - double l = loss_; - loss_ = 0.0; - return l; + if (!init_) throw std::runtime_error("[gctl::loss_func] Not initialized."); + if (x.size() != tnum_) throw std::runtime_error("[gctl::loss_func] Invalid array size."); + + for (size_t i = 0; i < tnum_; i++) + { + diff_[i] = (x[i] - tars_[i])/us_[i]; + } + + double loss = 0.0; + g.resize(tnum_); + + if (ntype_ == L1) + { + for (size_t i = 0; i < tnum_; i++) + { + loss += fabs(diff_[i]); + if (diff_[i] >= 0.0) g[i] = 1.0/(us_[i]*tnum_); + else g[i] = -1.0/(us_[i]*tnum_); + } + } + else if (ntype_ == L2) + { + for (size_t i = 0; i < tnum_; i++) + { + loss += diff_[i]*diff_[i]; + g[i] = 2.0*diff_[i]/(us_[i]*tnum_); + } + } + else throw std::runtime_error("[gctl::loss_func] Invalid measurement type."); + + return loss/tnum_; } - -double gctl::loss_func::evaluate(double inp, int id) -{ - double val = (inp - tars_[id]); - if (uncer_type_ == 1) val /= uncer_; - else if (uncer_type_ == 2) val /= uncers_[id]; - - if (norm_type_ == L1) val = fabs(val); - if (norm_type_ == L2) val = val*val; - - loss_ += val; - return val/tar_num_; -} - -double gctl::loss_func::gradient(double inp, int id) -{ - double c; - if (uncer_type_ == 1) c = uncer_; - else if (uncer_type_ == 2) c = uncers_[id]; - - double val = (inp - tars_[id]); - if (norm_type_ == L1 && val >= 0) val = 1.0; - if (norm_type_ == L1 && val < 0) val = -1.0; - if (norm_type_ == L2) val = 2.0*val; - - if (norm_type_ == L1 && uncer_type_ != 0) val /= c; - else if (norm_type_ == L2 && uncer_type_ != 0) val /= (c*c); - - return val/tar_num_; -} \ No newline at end of file diff --git a/lib/optimization/loss_func.h b/lib/optimization/loss_func.h index 2d4be1c..53784a0 100644 --- a/lib/optimization/loss_func.h +++ b/lib/optimization/loss_func.h @@ -43,18 +43,14 @@ namespace gctl void init(const array &tar, norm_type_e n_type); void set_uncertainty(double uncer); void set_uncertainty(const array &uncer); - double get_loss(); - double evaluate(double inp, int id); - double gradient(double inp, int id); + double evaluate(const array &x, array &g); private: - //unsigned int counter_; - unsigned int tar_num_; - int uncer_type_; - double uncer_, loss_; - norm_type_e norm_type_; - array tars_; - array uncers_; + bool init_; + unsigned int tnum_; + norm_type_e ntype_; + array tars_, diff_; + array us_; }; }