tmp update

This commit is contained in:
张壹 2024-10-07 12:54:38 +08:00
parent b0ad80f0b9
commit 9aedd1b8e8
2 changed files with 22 additions and 5 deletions

View File

@ -30,25 +30,33 @@
gctl::loss_func::loss_func()
{
init_ = false;
eps_ = 1e-8;
tnum_ = 0;
ntype_ = L2;
}
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type)
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type, double p, double eps)
{
init(tar, n_type);
}
gctl::loss_func::~loss_func(){}
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type)
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type, double p, double eps)
{
if (p < 1) throw std::runtime_error("[gctl::loss_func] Invalid power number.");
if (eps <= 0) throw std::runtime_error("[gctl::loss_func] Invalid epsilon value.");
init_ = true;
tnum_ = tar.size();
diff_.resize(tnum_);
us_.resize(tnum_, 1.0);
tars_ = tar;
ntype_ = n_type;
init_ = true;
eps_ = eps;
p_ = p;
return;
}
@ -97,6 +105,14 @@ double gctl::loss_func::evaluate(const array<double> &x, array<double> &g)
g[i] = 2.0*diff_[i]/(us_[i]*tnum_);
}
}
else if (ntype_ == Lp)
{
for (size_t i = 0; i < tnum_; i++)
{
loss += pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_);
g[i] = p_*pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_ - 1)*diff_[i]/(us_[i]*tnum_);
}
}
else throw std::runtime_error("[gctl::loss_func] Invalid measurement type.");
return loss/tnum_;

View File

@ -37,16 +37,17 @@ namespace gctl
{
public:
loss_func();
loss_func(const array<double> &tar, norm_type_e n_type);
loss_func(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
virtual ~loss_func();
void init(const array<double> &tar, norm_type_e n_type);
void init(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
void set_uncertainty(double uncer);
void set_uncertainty(const array<double> &uncer);
double evaluate(const array<double> &x, array<double> &g);
private:
bool init_;
double eps_, p_;
unsigned int tnum_;
norm_type_e ntype_;
array<double> tars_, diff_;