tmp update

This commit is contained in:
张壹 2024-10-09 14:30:46 +08:00
parent e13905b970
commit c2832c3104
3 changed files with 18 additions and 15 deletions

View File

@ -73,7 +73,7 @@ int main(int argc, char const *argv[]) try
array<double> x(201), y(301); array<double> x(201), y(301);
x.sequence(-1.0, 0.01); x.sequence(-1.0, 0.01);
y.sequence(0.0, 0.01); y.sequence(0.0, 0.01);
kde2d k(0.1, x, y); kde2d k(0.1, 0.1, x, y);
array<double> a(10000), b(10000); array<double> a(10000), b(10000);
a.random_float(0, 0.2, RdNormal, 0); a.random_float(0, 0.2, RdNormal, 0);

View File

@ -188,17 +188,18 @@ gctl::kde2d::kde2d(){}
gctl::kde2d::~kde2d(){} gctl::kde2d::~kde2d(){}
gctl::kde2d::kde2d(double h, const array<double> &x, const array<double> &y) gctl::kde2d::kde2d(double hx, double hy, const array<double> &x, const array<double> &y)
{ {
init(h, x, y); init(hx, hy, x, y);
} }
void gctl::kde2d::init(double h, const array<double> &x, const array<double> &y) void gctl::kde2d::init(double hx, double hy, const array<double> &x, const array<double> &y)
{ {
if (h <= 0) throw std::runtime_error("[gctl::kde2d] Invalid averaging width."); if (hx <= 0 || hy <= 0) throw std::runtime_error("[gctl::kde2d] Invalid averaging width.");
if (x.size() < 2 || y.size() < 2) throw std::runtime_error("[gctl::kde2d] Invalid sample size."); if (x.size() < 2 || y.size() < 2) throw std::runtime_error("[gctl::kde2d] Invalid sample size.");
h_ = h; hx_ = hx;
hy_ = hy;
xs_ = x.size(); xs_ = x.size();
ys_ = y.size(); ys_ = y.size();
x_ = x; x_ = x;
@ -224,10 +225,10 @@ void gctl::kde2d::get_distribution(const array<double> &mx,
out = 0.0; out = 0.0;
for (size_t k = 0; k < ms; k++) for (size_t k = 0; k < ms; k++)
{ {
out += gaussian_kernel((x_[j] - mx[k])/h_, (y_[i] - my[k])/h_); out += gaussian_kernel((x_[j] - mx[k])/hx_, (y_[i] - my[k])/hy_);
} }
dxy[i*xs_ + j] = out/(h_*h_*ms); dxy[i*xs_ + j] = out/(hx_*hy_*ms);
} }
} }
} }
@ -250,7 +251,7 @@ void gctl::kde2d::get_gradient_x_at(size_t mx_id, size_t my_id,
{ {
for (size_t j = 0; j < xs_; j++) for (size_t j = 0; j < xs_; j++)
{ {
dmx[i*xs_ + j] = ((x_[j] - mx[mx_id])/h_)*gaussian_kernel((x_[j] - mx[mx_id])/h_, (y_[i] - my[my_id])/h_)/(h_*h_*h_*ms); dmx[i*xs_ + j] = ((x_[j] - mx[mx_id])/hx_)*gaussian_kernel((x_[j] - mx[mx_id])/hx_, (y_[i] - my[my_id])/hy_)/(hx_*hx_*hy_*ms);
} }
} }
} }
@ -272,7 +273,7 @@ void gctl::kde2d::get_gradient_y_at(size_t mx_id, size_t my_id,
{ {
for (size_t j = 0; j < xs_; j++) for (size_t j = 0; j < xs_; j++)
{ {
dmy[i*xs_ + j] = ((y_[i] - my[my_id])/h_)*gaussian_kernel((x_[j] - mx[mx_id])/h_, (y_[i] - my[my_id])/h_)/(h_*h_*h_*ms); dmy[i*xs_ + j] = ((y_[i] - my[my_id])/hy_)*gaussian_kernel((x_[j] - mx[mx_id])/hx_, (y_[i] - my[my_id])/hy_)/(hy_*hy_*hx_*ms);
} }
} }
} }

View File

@ -112,20 +112,22 @@ namespace gctl
/** /**
* @brief Construct a new kde2d object * @brief Construct a new kde2d object
* *
* @param h * @param hx
* @param hy
* @param x * @param x
* @param y * @param y
*/ */
kde2d(double h, const array<double> &x, const array<double> &y); kde2d(double hx, double hy, const array<double> &x, const array<double> &y);
/** /**
* @brief * @brief
* *
* @param h * @param hx
* @param hy
* @param x * @param x
* @param y * @param y
*/ */
void init(double h, const array<double> &x, const array<double> &y); void init(double hx, double hy, const array<double> &x, const array<double> &y);
/** /**
* @brief Get the distribution object * @brief Get the distribution object
@ -168,7 +170,7 @@ namespace gctl
private: private:
size_t xs_, ys_; size_t xs_, ys_;
double h_; double hx_, hy_;
array<double> x_, y_; array<double> x_, y_;
double gaussian_kernel(double x, double y); double gaussian_kernel(double x, double y);