tmp
This commit is contained in:
parent
adf998fb8b
commit
bc8b09f693
@ -42,6 +42,7 @@ gctl::common_gradient::~common_gradient(){}
|
|||||||
void gctl::common_gradient::LCG_Ax(const array<double> &x, array<double> &ax)
|
void gctl::common_gradient::LCG_Ax(const array<double> &x, array<double> &ax)
|
||||||
{
|
{
|
||||||
matvec(t_, G_, x, NoTrans);
|
matvec(t_, G_, x, NoTrans);
|
||||||
|
vecmul(t_, t_, w_);
|
||||||
matvec(ax, G_, t_, Trans);
|
matvec(ax, G_, t_, Trans);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -52,6 +53,17 @@ void gctl::common_gradient::set_solver(const lcg_para ¶)
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void gctl::common_gradient::set_weights(const _1d_array &w)
|
||||||
|
{
|
||||||
|
if (w.size()!= Ln_) throw std::runtime_error("[gctl::common_gradient] Invalid array size.");
|
||||||
|
|
||||||
|
for (size_t i = 0; i < Ln_; i++)
|
||||||
|
{
|
||||||
|
w_[i] = 1.0/w[i];
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
void gctl::common_gradient::init(size_t Ln, size_t Mn)
|
void gctl::common_gradient::init(size_t Ln, size_t Mn)
|
||||||
{
|
{
|
||||||
Ln_ = Ln;
|
Ln_ = Ln;
|
||||||
@ -61,6 +73,7 @@ void gctl::common_gradient::init(size_t Ln, size_t Mn)
|
|||||||
G_.resize(Ln_, Mn_);
|
G_.resize(Ln_, Mn_);
|
||||||
t_.resize(Ln_);
|
t_.resize(Ln_);
|
||||||
gm_.resize(Ln_);
|
gm_.resize(Ln_);
|
||||||
|
w_.resize(Ln_, 1.0);
|
||||||
x_.resize(Ln_, 1.0);
|
x_.resize(Ln_, 1.0);
|
||||||
filled_.resize(Ln_, false);
|
filled_.resize(Ln_, false);
|
||||||
return;
|
return;
|
||||||
|
@ -56,6 +56,11 @@ namespace gctl
|
|||||||
*/
|
*/
|
||||||
void set_solver(const lcg_para ¶);
|
void set_solver(const lcg_para ¶);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Set the weights for the loss functions
|
||||||
|
*/
|
||||||
|
void set_weights(const _1d_array &w);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Initialize the common_gradient object
|
* @brief Initialize the common_gradient object
|
||||||
*
|
*
|
||||||
@ -84,7 +89,7 @@ namespace gctl
|
|||||||
size_t Ln_, Mn_; // Ln_: loss_func number,Mn_: model number
|
size_t Ln_, Mn_; // Ln_: loss_func number,Mn_: model number
|
||||||
_2d_matrix G_;
|
_2d_matrix G_;
|
||||||
_1d_array B_, g_, t_, x_;
|
_1d_array B_, g_, t_, x_;
|
||||||
_1d_array gm_;
|
_1d_array gm_, w_;
|
||||||
array<bool> filled_;
|
array<bool> filled_;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user