tmp update
This commit is contained in:
parent
b0ad80f0b9
commit
9aedd1b8e8
@ -30,25 +30,33 @@
|
||||
gctl::loss_func::loss_func()
|
||||
{
|
||||
init_ = false;
|
||||
eps_ = 1e-8;
|
||||
tnum_ = 0;
|
||||
ntype_ = L2;
|
||||
}
|
||||
|
||||
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type)
|
||||
gctl::loss_func::loss_func(const array<double> &tar, norm_type_e n_type, double p, double eps)
|
||||
{
|
||||
init(tar, n_type);
|
||||
}
|
||||
|
||||
gctl::loss_func::~loss_func(){}
|
||||
|
||||
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type)
|
||||
void gctl::loss_func::init(const array<double> &tar, norm_type_e n_type, double p, double eps)
|
||||
{
|
||||
if (p < 1) throw std::runtime_error("[gctl::loss_func] Invalid power number.");
|
||||
if (eps <= 0) throw std::runtime_error("[gctl::loss_func] Invalid epsilon value.");
|
||||
|
||||
init_ = true;
|
||||
|
||||
tnum_ = tar.size();
|
||||
diff_.resize(tnum_);
|
||||
us_.resize(tnum_, 1.0);
|
||||
tars_ = tar;
|
||||
|
||||
ntype_ = n_type;
|
||||
init_ = true;
|
||||
eps_ = eps;
|
||||
p_ = p;
|
||||
return;
|
||||
}
|
||||
|
||||
@ -97,6 +105,14 @@ double gctl::loss_func::evaluate(const array<double> &x, array<double> &g)
|
||||
g[i] = 2.0*diff_[i]/(us_[i]*tnum_);
|
||||
}
|
||||
}
|
||||
else if (ntype_ == Lp)
|
||||
{
|
||||
for (size_t i = 0; i < tnum_; i++)
|
||||
{
|
||||
loss += pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_);
|
||||
g[i] = p_*pow(diff_[i]*diff_[i] + eps_*eps_, 0.5*p_ - 1)*diff_[i]/(us_[i]*tnum_);
|
||||
}
|
||||
}
|
||||
else throw std::runtime_error("[gctl::loss_func] Invalid measurement type.");
|
||||
|
||||
return loss/tnum_;
|
||||
|
@ -37,16 +37,17 @@ namespace gctl
|
||||
{
|
||||
public:
|
||||
loss_func();
|
||||
loss_func(const array<double> &tar, norm_type_e n_type);
|
||||
loss_func(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
|
||||
virtual ~loss_func();
|
||||
|
||||
void init(const array<double> &tar, norm_type_e n_type);
|
||||
void init(const array<double> &tar, norm_type_e n_type, double p = 2.0, double eps = 1e-16);
|
||||
void set_uncertainty(double uncer);
|
||||
void set_uncertainty(const array<double> &uncer);
|
||||
double evaluate(const array<double> &x, array<double> &g);
|
||||
|
||||
private:
|
||||
bool init_;
|
||||
double eps_, p_;
|
||||
unsigned int tnum_;
|
||||
norm_type_e ntype_;
|
||||
array<double> tars_, diff_;
|
||||
|
Loading…
Reference in New Issue
Block a user