diff --git a/lib/optimization/loss_func.cpp b/lib/optimization/loss_func.cpp index ce890a5..c110f8b 100644 --- a/lib/optimization/loss_func.cpp +++ b/lib/optimization/loss_func.cpp @@ -30,25 +30,33 @@ gctl::loss_func::loss_func() { init_ = false; + eps_ = 1e-8; tnum_ = 0; ntype_ = L2; } -gctl::loss_func::loss_func(const array &tar, norm_type_e n_type) +gctl::loss_func::loss_func(const array &tar, norm_type_e n_type, double p, double eps) { init(tar, n_type); } gctl::loss_func::~loss_func(){} -void gctl::loss_func::init(const array &tar, norm_type_e n_type) +void gctl::loss_func::init(const array &tar, norm_type_e n_type, double p, double eps) { + if (p < 1) throw std::runtime_error("[gctl::loss_func] Invalid power number."); + if (eps <= 0) throw std::runtime_error("[gctl::loss_func] Invalid epsilon value."); + + init_ = true; + tnum_ = tar.size(); diff_.resize(tnum_); us_.resize(tnum_, 1.0); tars_ = tar; + ntype_ = n_type; - init_ = true; + eps_ = eps; + p_ = p; return; } @@ -97,6 +105,14 @@ double gctl::loss_func::evaluate(const array &x, array &g) g[i] = 2.0*diff_[i]/(us_[i]*tnum_); } } + else if (ntype_ == Lp) + { + for (size_t i = 0; i < tnum_; i++) + { + loss += pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_); + g[i] = p_*pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_ - 1)*diff_[i]/(us_[i]*tnum_); + } + } else throw std::runtime_error("[gctl::loss_func] Invalid measurement type."); return loss/tnum_; diff --git a/lib/optimization/loss_func.h b/lib/optimization/loss_func.h index 53784a0..b5c4070 100644 --- a/lib/optimization/loss_func.h +++ b/lib/optimization/loss_func.h @@ -37,16 +37,17 @@ namespace gctl { public: loss_func(); - loss_func(const array &tar, norm_type_e n_type); + loss_func(const array &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16); virtual ~loss_func(); - void init(const array &tar, norm_type_e n_type); + void init(const array &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16); void set_uncertainty(double uncer); void set_uncertainty(const array &uncer); double evaluate(const array &x, array &g); private: bool init_; + double eps_, p_; unsigned int tnum_; norm_type_e ntype_; array tars_, diff_;