This commit is contained in:
张壹 2025-02-22 20:17:42 +08:00
parent adf998fb8b
commit bc8b09f693
2 changed files with 19 additions and 1 deletions

View File

@ -42,6 +42,7 @@ gctl::common_gradient::~common_gradient(){}
void gctl::common_gradient::LCG_Ax(const array<double> &x, array<double> &ax)
{
matvec(t_, G_, x, NoTrans);
vecmul(t_, t_, w_);
matvec(ax, G_, t_, Trans);
return;
}
@ -52,6 +53,17 @@ void gctl::common_gradient::set_solver(const lcg_para &para)
return;
}
void gctl::common_gradient::set_weights(const _1d_array &w)
{
if (w.size()!= Ln_) throw std::runtime_error("[gctl::common_gradient] Invalid array size.");
for (size_t i = 0; i < Ln_; i++)
{
w_[i] = 1.0/w[i];
}
return;
}
void gctl::common_gradient::init(size_t Ln, size_t Mn)
{
Ln_ = Ln;
@ -61,6 +73,7 @@ void gctl::common_gradient::init(size_t Ln, size_t Mn)
G_.resize(Ln_, Mn_);
t_.resize(Ln_);
gm_.resize(Ln_);
w_.resize(Ln_, 1.0);
x_.resize(Ln_, 1.0);
filled_.resize(Ln_, false);
return;

View File

@ -56,6 +56,11 @@ namespace gctl
*/
void set_solver(const lcg_para &para);
/**
* @brief Set the weights for the loss functions
*/
void set_weights(const _1d_array &w);
/**
* @brief Initialize the common_gradient object
*
@ -84,7 +89,7 @@ namespace gctl
size_t Ln_, Mn_; // Ln_: loss_func numberMn_: model number
_2d_matrix G_;
_1d_array B_, g_, t_, x_;
_1d_array gm_;
_1d_array gm_, w_;
array<bool> filled_;
};
};