tmp
This commit is contained in:
parent
adf998fb8b
commit
bc8b09f693
@ -42,6 +42,7 @@ gctl::common_gradient::~common_gradient(){}
|
||||
void gctl::common_gradient::LCG_Ax(const array<double> &x, array<double> &ax)
|
||||
{
|
||||
matvec(t_, G_, x, NoTrans);
|
||||
vecmul(t_, t_, w_);
|
||||
matvec(ax, G_, t_, Trans);
|
||||
return;
|
||||
}
|
||||
@ -52,6 +53,17 @@ void gctl::common_gradient::set_solver(const lcg_para ¶)
|
||||
return;
|
||||
}
|
||||
|
||||
void gctl::common_gradient::set_weights(const _1d_array &w)
|
||||
{
|
||||
if (w.size()!= Ln_) throw std::runtime_error("[gctl::common_gradient] Invalid array size.");
|
||||
|
||||
for (size_t i = 0; i < Ln_; i++)
|
||||
{
|
||||
w_[i] = 1.0/w[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void gctl::common_gradient::init(size_t Ln, size_t Mn)
|
||||
{
|
||||
Ln_ = Ln;
|
||||
@ -61,6 +73,7 @@ void gctl::common_gradient::init(size_t Ln, size_t Mn)
|
||||
G_.resize(Ln_, Mn_);
|
||||
t_.resize(Ln_);
|
||||
gm_.resize(Ln_);
|
||||
w_.resize(Ln_, 1.0);
|
||||
x_.resize(Ln_, 1.0);
|
||||
filled_.resize(Ln_, false);
|
||||
return;
|
||||
|
@ -56,6 +56,11 @@ namespace gctl
|
||||
*/
|
||||
void set_solver(const lcg_para ¶);
|
||||
|
||||
/**
|
||||
* @brief Set the weights for the loss functions
|
||||
*/
|
||||
void set_weights(const _1d_array &w);
|
||||
|
||||
/**
|
||||
* @brief Initialize the common_gradient object
|
||||
*
|
||||
@ -84,7 +89,7 @@ namespace gctl
|
||||
size_t Ln_, Mn_; // Ln_: loss_func number,Mn_: model number
|
||||
_2d_matrix G_;
|
||||
_1d_array B_, g_, t_, x_;
|
||||
_1d_array gm_;
|
||||
_1d_array gm_, w_;
|
||||
array<bool> filled_;
|
||||
};
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user