diff --git a/lib/optimization/cmn_grad.cpp b/lib/optimization/cmn_grad.cpp index 3404c6e..2d54a6e 100644 --- a/lib/optimization/cmn_grad.cpp +++ b/lib/optimization/cmn_grad.cpp @@ -42,6 +42,7 @@ gctl::common_gradient::~common_gradient(){} void gctl::common_gradient::LCG_Ax(const array &x, array &ax) { matvec(t_, G_, x, NoTrans); + vecmul(t_, t_, w_); matvec(ax, G_, t_, Trans); return; } @@ -52,6 +53,17 @@ void gctl::common_gradient::set_solver(const lcg_para ¶) return; } +void gctl::common_gradient::set_weights(const _1d_array &w) +{ + if (w.size()!= Ln_) throw std::runtime_error("[gctl::common_gradient] Invalid array size."); + + for (size_t i = 0; i < Ln_; i++) + { + w_[i] = 1.0/w[i]; + } + return; +} + void gctl::common_gradient::init(size_t Ln, size_t Mn) { Ln_ = Ln; @@ -61,6 +73,7 @@ void gctl::common_gradient::init(size_t Ln, size_t Mn) G_.resize(Ln_, Mn_); t_.resize(Ln_); gm_.resize(Ln_); + w_.resize(Ln_, 1.0); x_.resize(Ln_, 1.0); filled_.resize(Ln_, false); return; diff --git a/lib/optimization/cmn_grad.h b/lib/optimization/cmn_grad.h index 5fcb21e..568b6b6 100644 --- a/lib/optimization/cmn_grad.h +++ b/lib/optimization/cmn_grad.h @@ -56,6 +56,11 @@ namespace gctl */ void set_solver(const lcg_para ¶); + /** + * @brief Set the weights for the loss functions + */ + void set_weights(const _1d_array &w); + /** * @brief Initialize the common_gradient object * @@ -84,7 +89,7 @@ namespace gctl size_t Ln_, Mn_; // Ln_: loss_func number,Mn_: model number _2d_matrix G_; _1d_array B_, g_, t_, x_; - _1d_array gm_; + _1d_array gm_, w_; array filled_; }; };