Files
stt/pybind/stt_binding.cpp
2025-11-27 15:06:01 +08:00

336 lines
12 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/functional.h>
#include <cstddef>
#include <cstring>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <locale>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// 在包含原始头文件之前定义PYTHON_BINDING宏
#define PYTHON_BINDING
// 包含进度条包装器
#include "python_progress_wrapper.h"
#include "progress_bar_python.h"
// 包含原始STT头文件
#include "../src/stt_class.h"
// 检查tqdm是否可用
bool has_tqdm() {
py::object tqdm_module;
try {
tqdm_module = py::module_::import("tqdm");
return true;
} catch (...) {
return false;
}
}
namespace py = pybind11;
// Python-friendly wrapper around SttGenerator.
//
// Each setter converts its numeric arguments to the "a/b" text form accepted
// by SttGenerator's char* setters, forwards it, and records the values so
// get_info() can report the configuration that was actually requested.
class PySttGenerator {
private:
    // Number of option slots passed to SttGenerator::Routine().
    static constexpr int kOptionCount = 14;
    // Capacity of each option slot, including the terminating NUL.
    static constexpr std::size_t kOptionLength = 1024;
    using OptionTable = char[kOptionCount][kOptionLength];
    // Slot used for the .msh output path (same index the original code used).
    static constexpr int kMshOptionIndex = 3;

    // WGS84 radii; these were the placeholder values get_info() reported before
    // any configuration was tracked.
    static constexpr double kWgs84EquatorRadius = 6378137.0;
    static constexpr double kWgs84PoleRadius = 6356752.314245;

    SttGenerator generator;

    // Values last passed through this wrapper. The initial values mirror the
    // previous hard-coded get_info() placeholders.
    // NOTE(review): not verified against SttGenerator's own defaults.
    int min_depth_ = 1;
    int max_depth_ = 5;
    std::string reference_system_ = "WGS84";
    double equator_radius_ = kWgs84EquatorRadius;
    double pole_radius_ = kWgs84PoleRadius;
    double icosahedron_longitude_ = 0.0;
    double icosahedron_latitude_ = 0.0;

    // Formats two doubles as "first/second" with enough digits to round-trip,
    // always using '.' as the decimal separator.
    static std::string join_pair(double first, double second) {
        std::ostringstream out;
        out.imbue(std::locale::classic());
        out << std::setprecision(std::numeric_limits<double>::max_digits10)
            << first << '/' << second;
        return out.str();
    }

    // Returns a private, writable NUL-terminated copy of `text` for APIs that
    // take a non-const char*, so a callee that edits its argument in place
    // never writes into std::string storage it does not own.
    static std::vector<char> writable_c_string(const std::string& text) {
        return std::vector<char>(text.c_str(), text.c_str() + text.size() + 1);
    }

    // Fills every option slot with the "NULL" placeholder used for options
    // that are not supplied.
    static void reset_options(OptionTable& options) {
        for (auto& slot : options) {
            std::strcpy(slot, "NULL");
        }
    }

    // Copies `value` into option slot `index`, always NUL-terminated.
    // Throws std::length_error (Python ValueError) instead of silently
    // truncating a value that does not fit into the slot.
    static void set_option(OptionTable& options, int index, const std::string& value) {
        if (value.size() >= kOptionLength) {
            throw std::length_error("option value is too long (" + std::to_string(value.size()) +
                                    " bytes, limit is " + std::to_string(kOptionLength - 1) + ")");
        }
        std::memcpy(options[index], value.c_str(), value.size() + 1);
    }

public:
    PySttGenerator() = default;

    // Sets the tree depth; forwarded to SttGenerator as "min/max".
    void set_tree_depth(int min_depth, int max_depth) {
        const std::string depth_spec = std::to_string(min_depth) + "/" + std::to_string(max_depth);
        std::vector<char> buffer = writable_c_string(depth_spec);
        generator.set_tree_depth(buffer.data());
        min_depth_ = min_depth;
        max_depth_ = max_depth;
    }

    // Sets an explicit ellipsoid; forwarded as "equator/pole". Marks the
    // reference system as "custom" in get_info().
    void set_pole_equator_radius(double equator_radius, double pole_radius) {
        std::vector<char> buffer = writable_c_string(join_pair(equator_radius, pole_radius));
        generator.set_pole_equator_radius(buffer.data());
        reference_system_ = "custom";
        equator_radius_ = equator_radius;
        pole_radius_ = pole_radius;
    }

    // Sets the icosahedron orientation; forwarded as "longitude/latitude".
    void set_icosahedron_orient(double longitude, double latitude) {
        std::vector<char> buffer = writable_c_string(join_pair(longitude, latitude));
        generator.set_icosahedron_orient(buffer.data());
        icosahedron_longitude_ = longitude;
        icosahedron_latitude_ = latitude;
    }

    // Selects a reference system by name. The string is handed verbatim to
    // SttGenerator::set_pole_equator_radius, which resolves it.
    void set_reference_system(const std::string& ref_system) {
        std::vector<char> buffer = writable_c_string(ref_system);
        generator.set_pole_equator_radius(buffer.data());
        reference_system_ = ref_system;
        if (ref_system == "WGS84") {
            equator_radius_ = kWgs84EquatorRadius;
            pole_radius_ = kWgs84PoleRadius;
        } else {
            // The radii for other names are resolved inside SttGenerator and are
            // not visible here; report NaN rather than misleading WGS84 values.
            equator_radius_ = std::numeric_limits<double>::quiet_NaN();
            pole_radius_ = std::numeric_limits<double>::quiet_NaN();
        }
    }

    // Installs a Python progress callback via ProgressCallbackManager.
    void set_progress_callback(py::object callback) {
        ProgressCallbackManager::set_callback(callback);
    }

    // Removes the progress callback by installing None.
    void clear_progress_callback() {
        ProgressCallbackManager::set_callback(py::none());
    }

    // Runs SttGenerator::Routine() with every option set to "NULL" except the
    // .msh output path, which is filled when `output_msh_file` is non-empty.
    // Returns Routine()'s result unchanged.
    // NOTE(review): slots 0-2 stay "NULL"; this assumes Routine() keeps the
    // values configured through the setters above — TODO confirm.
    int run(const std::string& output_msh_file = "") {
        OptionTable options{};
        reset_options(options);
        if (!output_msh_file.empty()) {
            set_option(options, kMshOptionIndex, output_msh_file);
        }
        return generator.Routine(options);
    }

    // Returns the configuration last requested through this wrapper plus fixed
    // package metadata. Radii are NaN after set_reference_system() with a name
    // other than "WGS84" (see there).
    py::dict get_info() const {
        py::dict info;
        info["min_depth"] = min_depth_;
        info["max_depth"] = max_depth_;
        info["reference_system"] = reference_system_;
        info["equator_radius"] = equator_radius_;
        info["pole_radius"] = pole_radius_;
        info["icosahedron_longitude"] = icosahedron_longitude_;
        info["icosahedron_latitude"] = icosahedron_latitude_;
        info["version"] = "1.0.0";
        info["author"] = "STT Development Team";
        info["email"] = "stt@example.com";
        info["license"] = "MIT";
        info["url"] = "https://github.com/stt/stt-generator";
        return info;
    }

    // Runs SttGenerator::Routine() with options taken from `params`.
    // Each recognised key is converted with Python str() and stored in its
    // option slot; unknown keys are ignored and missing keys stay "NULL".
    // Throws std::length_error (ValueError) when a value does not fit a slot.
    int run_full(const py::dict& params) {
        // Accepted keys and the Routine() option slot each one fills.
        static const std::pair<const char*, int> kKeyToSlot[] = {
            {"output_msh", 3},
            {"output_vertex", 4},
            {"output_triangle_center", 5},
            {"output_neighbor", 6},
            {"control_points", 7},
            {"control_lines", 8},
            {"control_polygons", 9},
            {"control_circles", 10},
            {"outline_shape", 11},
            {"hole_shape", 12},
            {"topography", 13},
        };

        OptionTable options{};
        reset_options(options);
        for (const auto& key_slot : kKeyToSlot) {
            if (params.contains(key_slot.first)) {
                const std::string value = py::str(params[key_slot.first]);
                set_option(options, key_slot.second, value);
            }
        }
        return generator.Routine(options);
    }
};
// Convenience entry point: configures a fresh generator with the given depth
// range and reference system, runs it once writing `output_file`, and returns
// a dict echoing the request plus "success" (true when Routine() returned 0).
py::dict create_stt(int min_depth, int max_depth,
                    const std::string& reference_system = "WGS84",
                    const std::string& output_file = "output.msh") {
    PySttGenerator stt;
    stt.set_tree_depth(min_depth, max_depth);
    stt.set_reference_system(reference_system);
    const int status = stt.run(output_file);
    const bool succeeded = status == 0;

    py::dict summary;
    summary["min_depth"] = min_depth;
    summary["max_depth"] = max_depth;
    summary["reference_system"] = reference_system;
    summary["output_file"] = output_file;
    summary["success"] = succeeded;
    return summary;
}
// Module-level metadata exposed to Python as pystt.get_info(): version,
// authorship, contact, licence, homepage and a one-line description.
py::dict module_get_info() {
    // Key/value pairs in the order they are inserted into the dict.
    static const char* const kFields[][2] = {
        {"version", "1.0.0"},
        {"author", "STT Development Team"},
        {"email", "stt@example.com"},
        {"license", "MIT"},
        {"url", "https://github.com/stt/stt-generator"},
        {"description", "Spherical Triangular Tessellation (STT) generator"},
    };

    py::dict metadata;
    for (const auto& field : kFields) {
        metadata[field[0]] = field[1];
    }
    return metadata;
}
// Python module "pystt": binds PySttGenerator as SttGenerator, plus the
// convenience helpers, progress-callback factories and module constants.
PYBIND11_MODULE(pystt, m) {
    m.doc() = R"pbdoc(
STT Python Binding
------------------
A Python interface for the Spherical Triangular Tessellation (STT) generator.
This module provides access to the C++ STT library for generating
spherical triangular tessellations on various reference systems.
Features:
- Support for multiple reference systems (WGS84, Earth, Moon, custom)
- Configurable tree depth for mesh refinement
- Progress callback support for Jupyter notebooks
- Output to various file formats (.msh, .txt)
)pbdoc";

    // Main class binding.
    py::class_<PySttGenerator>(m, "SttGenerator")
        .def(py::init<>(), "Create a new STT generator instance")
        .def("set_tree_depth", &PySttGenerator::set_tree_depth,
             "Set the minimum and maximum tree depth",
             py::arg("min_depth"), py::arg("max_depth"))
        .def("set_pole_equator_radius", &PySttGenerator::set_pole_equator_radius,
             "Set the pole and equator radius for the reference system",
             py::arg("equator_radius"), py::arg("pole_radius"))
        .def("set_icosahedron_orient", &PySttGenerator::set_icosahedron_orient,
             "Set the orientation of the icosahedron top vertex",
             py::arg("longitude"), py::arg("latitude"))
        .def("set_reference_system", &PySttGenerator::set_reference_system,
             "Set the reference system (WGS84, Earth, Moon, or custom)",
             py::arg("ref_system"))
        .def("set_progress_callback", &PySttGenerator::set_progress_callback,
             "Set a progress callback function for Jupyter notebook compatibility\n"
             "Callback function should accept (description, percentage) parameters",
             py::arg("callback"))
        .def("clear_progress_callback", &PySttGenerator::clear_progress_callback,
             "Clear the progress callback function")
        .def("run", &PySttGenerator::run,
             "Run the STT generation with basic parameters",
             py::arg("output_msh_file") = "")
        .def("run_full", &PySttGenerator::run_full,
             "Run the STT generation with full parameters",
             py::arg("params"))
        .def("get_info", &PySttGenerator::get_info,
             "Get information about the current STT generator configuration");

    // Convenience function.
    m.def("create_stt", &create_stt,
          "Create STT with simplified interface",
          py::arg("min_depth"), py::arg("max_depth"),
          py::arg("reference_system") = "WGS84",
          py::arg("output_file") = "output.msh");

    // Module-level get_info.
    m.def("get_info", &module_get_info,
          "Get module information including version, author, and contact details");

    // Builds a Python callable (desc, percentage) that prints
    // "<description> - <desc>: <percentage>%" and flushes. Output goes through
    // Python's print(), i.e. sys.stdout, so it appears in Jupyter cells; C++
    // std::cout writes to the process stdout, which notebook kernels generally
    // do not capture into cell output.
    const auto make_print_callback = [](const std::string& description) {
        return py::cpp_function([description](const std::string& desc, double percentage) {
            std::ostringstream line;
            line << description << " - " << desc << ": " << percentage << "%";
            py::print(line.str(), py::arg("flush") = true);
        });
    };

    m.def("create_simple_callback", make_print_callback,
          "Create a simple progress callback function");

    m.def("create_tqdm_callback", [make_print_callback](const std::string& description) {
        if (!has_tqdm()) {
            throw std::runtime_error("tqdm is not available");
        }
        // No tqdm bar is created yet; this falls back to the plain printer.
        return make_print_callback(description);
    }, "Create a tqdm-based progress callback function");

    // Always returns the plain printer, regardless of tqdm availability.
    m.def("create_progress_callback", [make_print_callback](const std::string& description) {
        return make_print_callback(description);
    }, "Create the best available progress callback function");

    // Reference-system name constants.
    m.attr("WGS84") = "WGS84";
    m.attr("EARTH") = "Earth";
    m.attr("MOON") = "Moon";

    // Evaluated once, at import time.
    m.attr("HAS_TQDM") = has_tqdm();

    // Version information, injected by the build when available.
#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif
}