tmp update

This commit is contained in:
张壹 2024-10-07 12:54:38 +08:00
parent b0ad80f0b9
commit 9aedd1b8e8
2 changed files with 22 additions and 5 deletions

View File

@ -30,25 +30,33 @@
gctl::loss_func::loss_func() gctl::loss_func::loss_func()
{ {
init_ = false; init_ = false;
eps_ = 1e-8;
tnum_ = 0; tnum_ = 0;
ntype_ = L2; ntype_ = L2;
} }
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type) gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type, double p, double eps)
{ {
init(tar, n_type); init(tar, n_type);
} }
gctl::loss_func::~loss_func(){} gctl::loss_func::~loss_func(){}
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type) void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type, double p, double eps)
{ {
if (p < 1) throw std::runtime_error("[gctl::loss_func] Invalid power number.");
if (eps <= 0) throw std::runtime_error("[gctl::loss_func] Invalid epsilon value.");
init_ = true;
tnum_ = tar.size(); tnum_ = tar.size();
diff_.resize(tnum_); diff_.resize(tnum_);
us_.resize(tnum_, 1.0); us_.resize(tnum_, 1.0);
tars_ = tar; tars_ = tar;
ntype_ = n_type; ntype_ = n_type;
init_ = true; eps_ = eps;
p_ = p;
return; return;
} }
@ -97,6 +105,14 @@ double gctl::loss_func::evaluate(const array<double> &x, array<double> &g)
g[i] = 2.0*diff_[i]/(us_[i]*tnum_); g[i] = 2.0*diff_[i]/(us_[i]*tnum_);
} }
} }
else if (ntype_ == Lp)
{
for (size_t i = 0; i < tnum_; i++)
{
loss += pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_);
g[i] = p_*pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_ - 1)*diff_[i]/(us_[i]*tnum_);
}
}
else throw std::runtime_error("[gctl::loss_func] Invalid measurement type."); else throw std::runtime_error("[gctl::loss_func] Invalid measurement type.");
return loss/tnum_; return loss/tnum_;

View File

@ -37,16 +37,17 @@ namespace gctl
{ {
public: public:
loss_func(); loss_func();
loss_func(const array<double> &tar, norm_type_e n_type); loss_func(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
virtual ~loss_func(); virtual ~loss_func();
void init(const array<double> &tar, norm_type_e n_type); void init(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
void set_uncertainty(double uncer); void set_uncertainty(double uncer);
void set_uncertainty(const array<double> &uncer); void set_uncertainty(const array<double> &uncer);
double evaluate(const array<double> &x, array<double> &g); double evaluate(const array<double> &x, array<double> &g);
private: private:
bool init_; bool init_;
double eps_, p_;
unsigned int tnum_; unsigned int tnum_;
norm_type_e ntype_; norm_type_e ntype_;
array<double> tars_, diff_; array<double> tars_, diff_;