This commit is contained in:
张壹 2024-09-30 14:15:16 +08:00
parent b0b98ff47e
commit 4e9f933f8f
4 changed files with 122 additions and 126 deletions

View File

@ -21,7 +21,7 @@ add_example(fft_filter_ex OFF)
add_example(windowfunc_ex OFF)
add_example(legendre_ex OFF)
add_example(refellipsoid_ex OFF)
add_example(kde_ex OFF)
add_example(kde_ex ON)
add_example(meshio_ex OFF)
add_example(autodiff_ex OFF)
add_example(multinary_ex OFF)

View File

@ -31,27 +31,45 @@
using namespace gctl;
int main(int argc, char const *argv[]) try
{
/*
{
array<double> x(201);
x.sequence(-1.0, 0.01);
kde k(0.02, x);
array<double> d;
array<double> a(100000);
a.random(0, 0.2, RdNormal, 0);
kde k(0.02, a);
array<double> x;
linespace(-1.0, 1.0, 201, x);
a.random_float(0.2, 0.2, RdNormal, 0);
k.get_distribution(a, d);
gaussian_para1d g1(0, 0.2);
array<double> g(201);
for (size_t i = 0; i < x.size(); i++)
{
std::cout << x[i] << " "
<< gaussian_dist1d(x[i], g1) << " "
<< k.get_density_at(x[i]) << " "
<< k.get_density_at(x[i], KDE_Epanechnikov) << " "
<< k.get_density_at(x[i], KDE_Rectangular) << " "
<< k.get_density_at(x[i], KDE_Triangular) << "\n";
g[i] = gaussian_dist1d(x[i], g1);
}
*/
array<double> dm(201);
array<double> am(100000, 0.0);
/*
for (size_t i = 0; i < am.size(); i++)
{
k.get_gradient_at(i, a, dm);
for (size_t j = 0; j < x.size(); j++)
{
am[i] += 2.0*(d[j] - g[j])*dm[j];
}
std::cout << a[i] << " " << am[i] << "\n";
}
*/
for (size_t i = 0; i < x.size(); i++)
{
std::cout << x[i] << " " << d[i] - g[i] << "\n";
}
/*
array<double> a(100000), b(100000);
a.random(0, 0.2, RdNormal, 0);
b.random(1, 0.2, RdNormal, 0);
@ -78,6 +96,7 @@ int main(int argc, char const *argv[]) try
}
std::clog << "sum = " << sum << "\n";
*/
return 0;
}
catch(std::exception &e)

View File

@ -43,149 +43,106 @@ void gctl::kde::init(double h, const array<double> &x)
h_ = h;
x_ = x;
xs_ = x.size();
return;
}
double gctl::kde::get_density_at(double x, kde_kernel_e k_type)
void gctl::kde::get_distribution(const array<double> &m, array<double> &d,
kde_kernel_e k_type)
{
double out = 0;
double out;
int ms = m.size();
d.resize(xs_);
if (k_type == KDE_Gaussian)
{
for (size_t i = 0; i < x_.size(); i++)
for (size_t i = 0; i < xs_; i++)
{
out += gaussian_kernel((x - x_[i])/h_);
out = 0;
for (size_t j = 0; j < ms; j++)
{
out += gaussian_kernel((x_[i] - m[j])/h_);
}
d[i] = out/(h_*ms);
}
}
else if (k_type == KDE_Epanechnikov)
{
for (size_t i = 0; i < x_.size(); i++)
for (size_t i = 0; i < xs_; i++)
{
out += epanechnikov_kernel((x - x_[i])/h_);
out = 0;
for (size_t j = 0; j < ms; j++)
{
out += epanechnikov_kernel((x_[i] - m[j])/h_);
}
d[i] = out/(h_*ms);
}
}
else if (k_type == KDE_Rectangular)
{
for (size_t i = 0; i < x_.size(); i++)
for (size_t i = 0; i < xs_; i++)
{
out += rectangular_kernel((x - x_[i])/h_);
out = 0;
for (size_t j = 0; j < ms; j++)
{
out += rectangular_kernel((x_[i] - m[j])/h_);
}
d[i] = out/(h_*ms);
}
}
else
{
for (size_t i = 0; i < x_.size(); i++)
for (size_t i = 0; i < xs_; i++)
{
out += triangular_kernel((x - x_[i])/h_);
out = 0;
for (size_t j = 0; j < ms; j++)
{
out += triangular_kernel((x_[i] - m[j])/h_);
}
d[i] = out/(h_*ms);
}
}
return out/(h_*x_.size());
return;
}
double gctl::kde::get_kernel_density_at(size_t k_id, double x, kde_kernel_e k_type)
void gctl::kde::get_gradient_at(size_t m_id, const array<double> &m,
array<double> &dm, kde_kernel_e k_type)
{
if (k_id >= x_.size()) throw std::runtime_error("[gctl::kde::get_kernel_density_at(...)] Invalid kernel index.");
dm.resize(xs_);
int ms = m.size();
double out;
if (k_type == KDE_Gaussian) out = gaussian_kernel((x - x_[k_id])/h_);
else if (k_type == KDE_Epanechnikov) out = epanechnikov_kernel((x - x_[k_id])/h_);
else if (k_type == KDE_Rectangular) out = rectangular_kernel((x - x_[k_id])/h_);
else out = triangular_kernel((x - x_[k_id])/h_);
return out/h_;
}
double gctl::kde::get_gradient_at(double x, kde_kernel_e k_type)
{
double out = 0;
if (k_type == KDE_Gaussian)
{
for (size_t i = 0; i < x_.size(); i++)
for (size_t i = 0; i < xs_; i++)
{
out += ((x - x_[i])/h_)*gaussian_kernel((x - x_[i])/h_);
dm[i] = ((x_[i] - m[m_id])/h_)*gaussian_kernel((x_[i] - m[m_id])/h_)/(h_*h_*ms);
}
}
else if (k_type == KDE_Epanechnikov)
{
for (size_t i = 0; i < x_.size(); i++)
for (size_t i = 0; i < xs_; i++)
{
out += epanechnikov_kernel((x - x_[i])/h_, true);
dm[i] = -1.0*epanechnikov_kernel((x_[i] - m[m_id])/h_, true)/(h_*h_*ms);
}
}
else if (k_type == KDE_Rectangular)
{
for (size_t i = 0; i < x_.size(); i++)
for (size_t i = 0; i < xs_; i++)
{
out += rectangular_kernel((x - x_[i])/h_, true);
dm[i] = -1.0*rectangular_kernel((x_[i] - m[m_id])/h_, true)/(h_*h_*ms);
}
}
else
{
for (size_t i = 0; i < x_.size(); i++)
for (size_t i = 0; i < xs_; i++)
{
out += triangular_kernel((x - x_[i])/h_, true);
dm[i] = -1.0*triangular_kernel((x_[i] - m[m_id])/h_, true)/(h_*h_*ms);
}
}
return -1.0*out/(h_*h_*x_.size());
}
double gctl::kde::get_kernel_gradient_at(size_t k_id, double x, kde_kernel_e k_type)
{
if (k_id >= x_.size()) throw std::runtime_error("[gctl::kde::get_kernel_gradient_at(...)] Invalid kernel index.");
double out;
if (k_type == KDE_Gaussian) out = ((x - x_[k_id])/h_)*gaussian_kernel((x - x_[k_id])/h_);
else if (k_type == KDE_Epanechnikov) out = epanechnikov_kernel((x - x_[k_id])/h_);
else if (k_type == KDE_Rectangular) out = rectangular_kernel((x - x_[k_id])/h_);
else out = triangular_kernel((x - x_[k_id])/h_);
return -1.0*out/(h_*h_);
}
void gctl::kde::get_distribution(const array<double> x, array<double> &d,
array<double> &dx, kde_kernel_e k_type, kde_norm_e n_type, double norm)
{
if (norm < 0.0) throw std::runtime_error("[GCTL Kernel Density Estimation] Invalid normalization value.");
size_t xnum = x.size();
d.resize(xnum);
dx.resize(xnum);
double s = 0.0;
if (n_type == KDE_MAX2ONE)
{
for (size_t i = 0; i < xnum; i++)
{
d[i] = get_density_at(x[i], k_type);
dx[i]= get_gradient_at(x[i], k_type);
s = std::max(s, d[i]);
}
}
else if (n_type == KDE_SUM2ONE)
{
for (size_t i = 0; i < xnum; i++)
{
d[i] = get_density_at(x[i], k_type);
dx[i]= get_gradient_at(x[i], k_type);
s += d[i];
}
}
else
{
for (size_t i = 0; i < xnum; i++)
{
d[i] = get_density_at(x[i], k_type);
dx[i]= get_gradient_at(x[i], k_type);
}
s = norm;
}
for (size_t i = 0; i < xnum; i++)
{
d[i] /= s;
dx[i]/= s;
}
return;
}
@ -199,7 +156,7 @@ double gctl::kde::epanechnikov_kernel(double x, bool gradient)
if (gradient)
{
if (fabs(x) >= 1) return 0;
else return 1.5*x;
else return -1.5*x;
}
if (fabs(x) >= 1) return 0;
@ -219,8 +176,8 @@ double gctl::kde::triangular_kernel(double x, bool gradient)
if (gradient)
{
if (fabs(x) >= 1) return 0;
else if (x >= 0) return 1.0;
else return -1.0;
else if (x >= 0) return -1.0;
else return 1.0;
}
if (fabs(x) >= 1) return 0;

View File

@ -53,27 +53,47 @@ namespace gctl
{
public:
kde();
kde(double h, const array<double> &x);
virtual ~kde();
void init(double h, const array<double> &x);
double get_density_at(double x, kde_kernel_e k_type = KDE_Gaussian);
/**
* @brief Construct a new kde object
*
* @param h
* @param s
*/
kde(double h, const array<double> &s);
/**
* @brief Get the probability density of a single kernel. Note the value is not normalized by the kernel number.
* @brief
*
* @param k_id kernel index
* @param x inquiring location
* @param k_type kernel type
* @return kernel value
* @param h
* @param x
*/
double get_kernel_density_at(size_t k_id, double x, kde_kernel_e k_type = KDE_Gaussian);
double get_gradient_at(double x, kde_kernel_e k_type = KDE_Gaussian);
double get_kernel_gradient_at(size_t k_id, double x, kde_kernel_e k_type = KDE_Gaussian);
void get_distribution(const array<double> x, array<double> &d, array<double> &dx,
kde_kernel_e k_type = KDE_Gaussian, kde_norm_e n_type = KDE_MAX2ONE, double norm = 1.0);
void init(double h, const array<double> &s);
/**
* @brief
*
* @param m
* @param d
* @param k_type
*/
void get_distribution(const array<double> &m, array<double> &d,
kde_kernel_e k_type = KDE_Gaussian);
/**
* @brief
*
* @param m_id
* @param m
* @param dm m_id的偏导数
* @param k_type
*/
void get_gradient_at(size_t m_id, const array<double> &m, array<double> &dm,
kde_kernel_e k_type = KDE_Gaussian);
private:
size_t xs_;
double h_;
array<double> x_;