Files
stt/pybind/stt_binding.cpp

336 lines
12 KiB
C++
Raw Normal View History

2025-11-27 15:06:01 +08:00
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <pybind11/functional.h>
#include <cstdio>
#include <cstring>
#include <functional>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
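// Define PYTHON_BINDING before including the project headers below (they presumably check it to enable Python-specific code paths).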
#define PYTHON_BINDING
// Progress-bar wrapper headers for the Python binding.
#include "python_progress_wrapper.h"
#include "progress_bar_python.h"
// Original STT library header.
#include "../src/stt_class.h"
// 检查tqdm是否可用
bool has_tqdm() {
py::object tqdm_module;
try {
tqdm_module = py::module_::import("tqdm");
return true;
} catch (...) {
return false;
}
}
namespace py = pybind11;
// Python-friendly wrapper around SttGenerator.
//
// SttGenerator's setters take "a/b"-style option strings through non-const
// char* parameters. SttGenerator::Routine takes a table of 14 option strings
// of 1024 chars each. This class:
//   - turns typed Python arguments into those strings;
//   - assembles the option table;
//   - records the values it forwarded so get_info() can report them.
class PySttGenerator {
private:
    // Shape of the option table handed to SttGenerator::Routine.
    static constexpr int kOptionCount = 14;
    static constexpr std::size_t kOptionLength = 1024;

    // WGS84 semi-major (equatorial) and semi-minor (polar) axes, in metres.
    static constexpr double kWgs84EquatorRadius = 6378137.0;
    static constexpr double kWgs84PoleRadius = 6356752.314245;

    SttGenerator generator;

    // Values most recently forwarded to `generator` through this wrapper,
    // reported by get_info(). The initial values are the placeholders the
    // wrapper always reported before it tracked state.
    // NOTE(review): confirm they match SttGenerator's own defaults.
    int min_depth_ = 1;
    int max_depth_ = 5;
    std::string reference_system_ = "WGS84";
    bool radii_known_ = true;  // false once a non-WGS84 named system is selected
    double equator_radius_ = kWgs84EquatorRadius;
    double pole_radius_ = kWgs84PoleRadius;
    double icosahedron_longitude_ = 0.0;
    double icosahedron_latitude_ = 0.0;

    // Formats a double with 17 significant digits so it round-trips exactly.
    // std::to_string uses "%f", which keeps only 6 decimal places
    // (1e-7 would become "0.000000").
    static std::string format_double(double value) {
        char buffer[32];
        std::snprintf(buffer, sizeof(buffer), "%.17g", value);
        return buffer;
    }

    // Joins two values into the "first/second" form the setters receive.
    static std::string join_pair(const std::string& first, const std::string& second) {
        return first + "/" + second;
    }

    // Returns a NUL-terminated, writable copy of `text`. The SttGenerator
    // setters take a non-const char*. Handing them const_cast<char*>(c_str())
    // would be undefined behaviour if a setter ever wrote to its argument.
    static std::vector<char> writable_copy(const std::string& text) {
        std::vector<char> buffer(text.begin(), text.end());
        buffer.push_back('\0');
        return buffer;
    }

    // Fills every slot with "NULL", the placeholder Routine receives for
    // options that were not supplied.
    static void reset_options(char (&options)[kOptionCount][kOptionLength]) {
        for (auto& slot : options) {
            std::strcpy(slot, "NULL");
        }
    }

    // Copies `value` (including its terminating NUL) into `slot`.
    // Over-long values are rejected rather than truncated. Truncating with
    // strncpy would leave the slot unterminated and silently point at a
    // different path.
    //
    // Throws std::invalid_argument, which pybind11 raises as ValueError.
    static void set_option(char (&slot)[kOptionLength], const std::string& value,
                           const std::string& name) {
        if (value.size() >= kOptionLength) {
            throw std::invalid_argument(name + " is too long: " +
                                        std::to_string(value.size()) +
                                        " characters (maximum " +
                                        std::to_string(kOptionLength - 1) + ")");
        }
        std::memcpy(slot, value.c_str(), value.size() + 1);
    }

public:
    PySttGenerator() = default;

    // Sets the minimum and maximum tree depth, forwarded as "min/max".
    void set_tree_depth(int min_depth, int max_depth) {
        std::vector<char> arg =
            writable_copy(join_pair(std::to_string(min_depth), std::to_string(max_depth)));
        generator.set_tree_depth(arg.data());
        min_depth_ = min_depth;
        max_depth_ = max_depth;
    }

    // Sets custom ellipsoid radii, forwarded as "equator/pole" at full precision.
    void set_pole_equator_radius(double equator_radius, double pole_radius) {
        std::vector<char> arg =
            writable_copy(join_pair(format_double(equator_radius), format_double(pole_radius)));
        generator.set_pole_equator_radius(arg.data());
        reference_system_ = "custom";
        radii_known_ = true;
        equator_radius_ = equator_radius;
        pole_radius_ = pole_radius;
    }

    // Sets the longitude/latitude of the icosahedron's top vertex, forwarded
    // as "longitude/latitude" at full precision.
    void set_icosahedron_orient(double longitude, double latitude) {
        std::vector<char> arg =
            writable_copy(join_pair(format_double(longitude), format_double(latitude)));
        generator.set_icosahedron_orient(arg.data());
        icosahedron_longitude_ = longitude;
        icosahedron_latitude_ = latitude;
    }

    // Selects a preset reference system by name (e.g. "WGS84", "Earth", "Moon").
    // The name is forwarded unchanged to SttGenerator::set_pole_equator_radius,
    // which presumably resolves the preset itself.
    void set_reference_system(const std::string& ref_system) {
        std::vector<char> arg = writable_copy(ref_system);
        generator.set_pole_equator_radius(arg.data());
        reference_system_ = ref_system;
        // The wrapper only knows the WGS84 radii; other presets are resolved
        // inside SttGenerator, so get_info() reports their radii as None.
        // NOTE(review): assumes SttGenerator's "WGS84" uses the standard axes.
        radii_known_ = (ref_system == "WGS84");
        if (radii_known_) {
            equator_radius_ = kWgs84EquatorRadius;
            pole_radius_ = kWgs84PoleRadius;
        }
    }

    // Installs a Python callable as the progress callback.
    void set_progress_callback(py::object callback) {
        ProgressCallbackManager::set_callback(callback);
    }

    // Removes any installed progress callback.
    void clear_progress_callback() {
        ProgressCallbackManager::set_callback(py::none());
    }

    // Runs SttGenerator::Routine and returns its status code unchanged.
    // If `output_msh_file` is non-empty it goes into option slot 3 (the .msh
    // output); an empty name leaves that slot as "NULL".
    //
    // NOTE(review): slots 0-2 (presumably depth, radius and orientation, which
    // the setters above configure) are always passed as "NULL". Confirm that
    // Routine does not reset those settings when it sees "NULL".
    int run(const std::string& output_msh_file = "") {
        char options[kOptionCount][kOptionLength] = {};
        reset_options(options);
        if (!output_msh_file.empty()) {
            set_option(options[3], output_msh_file, "output_msh_file");
        }
        return generator.Routine(options);
    }

    // Returns the configuration last forwarded through this wrapper (depths,
    // reference system, radii, icosahedron orientation) plus package metadata.
    // Radii are None after a named reference system other than WGS84 was
    // selected, because this wrapper does not know those values.
    py::dict get_info() const {
        py::dict info;
        info["min_depth"] = min_depth_;
        info["max_depth"] = max_depth_;
        info["reference_system"] = reference_system_;
        if (radii_known_) {
            info["equator_radius"] = equator_radius_;
            info["pole_radius"] = pole_radius_;
        } else {
            info["equator_radius"] = py::none();
            info["pole_radius"] = py::none();
        }
        info["icosahedron_longitude"] = icosahedron_longitude_;
        info["icosahedron_latitude"] = icosahedron_latitude_;
        info["version"] = "1.0.0";
        info["author"] = "STT Development Team";
        info["email"] = "stt@example.com";
        info["license"] = "MIT";
        info["url"] = "https://github.com/stt/stt-generator";
        return info;
    }

    // Runs SttGenerator::Routine with file options taken from `params`.
    // Each recognised key fills its option slot with str(value). A key that is
    // missing, None, or an empty string leaves its slot as "NULL", matching
    // run(). Unrecognised keys are ignored. Returns Routine's status code.
    int run_full(const py::dict& params) {
        // Recognised dictionary keys and the option slot each one fills.
        static const struct {
            const char* key;
            int slot;
        } kFileOptions[] = {
            {"output_msh", 3},
            {"output_vertex", 4},
            {"output_triangle_center", 5},
            {"output_neighbor", 6},
            {"control_points", 7},
            {"control_lines", 8},
            {"control_polygons", 9},
            {"control_circles", 10},
            {"outline_shape", 11},
            {"hole_shape", 12},
            {"topography", 13},
        };

        char options[kOptionCount][kOptionLength] = {};
        reset_options(options);
        for (const auto& option : kFileOptions) {
            if (!params.contains(option.key)) {
                continue;
            }
            py::object value = params[option.key];
            if (value.is_none()) {
                continue;
            }
            const std::string text = py::str(value);
            if (text.empty()) {
                continue;
            }
            set_option(options[option.slot], text, option.key);
        }
        return generator.Routine(options);
    }
};
// Convenience one-shot entry point. It builds a generator with the given depth
// range and reference system, then runs it with `output_file` as the .msh
// output. It returns a dict that echoes the inputs, plus `success`, which is
// true when the underlying run returned 0.
py::dict create_stt(int min_depth, int max_depth,
                    const std::string& reference_system = "WGS84",
                    const std::string& output_file = "output.msh") {
    PySttGenerator stt;
    stt.set_tree_depth(min_depth, max_depth);
    stt.set_reference_system(reference_system);
    const int status = stt.run(output_file);
    const bool succeeded = (status == 0);

    py::dict summary;
    summary["min_depth"] = min_depth;
    summary["max_depth"] = max_depth;
    summary["reference_system"] = reference_system;
    summary["output_file"] = output_file;
    summary["success"] = succeeded;
    return summary;
}
// Module-level metadata exposed to Python as `pystt.get_info()`: version,
// author, contact email, licence, project URL and a one-line description.
py::dict module_get_info() {
    // (key, value) pairs, inserted in this order.
    static const char* const kFields[][2] = {
        {"version", "1.0.0"},
        {"author", "STT Development Team"},
        {"email", "stt@example.com"},
        {"license", "MIT"},
        {"url", "https://github.com/stt/stt-generator"},
        {"description", "Spherical Triangular Tessellation (STT) generator"},
    };

    py::dict metadata;
    for (const auto& field : kFields) {
        metadata[field[0]] = field[1];
    }
    return metadata;
}
PYBIND11_MODULE(pystt, m) {
m.doc() = R"pbdoc(
STT Python Binding
------------------
A Python interface for the Spherical Triangular Tessellation (STT) generator.
This module provides access to the C++ STT library for generating
spherical triangular tessellations on various reference systems.
Features:
- Support for multiple reference systems (WGS84, Earth, Moon, custom)
- Configurable tree depth for mesh refinement
- Progress callback support for Jupyter notebooks
- Output to various file formats (.msh, .txt)
)pbdoc";
// 进度条回调函数类型定义
using ProgressCallback = std::function<void(const std::string&, double)>;
// 主类绑定
py::class_<PySttGenerator>(m, "SttGenerator")
.def(py::init<>(), "Create a new STT generator instance")
.def("set_tree_depth", &PySttGenerator::set_tree_depth,
"Set the minimum and maximum tree depth",
py::arg("min_depth"), py::arg("max_depth"))
.def("set_pole_equator_radius", &PySttGenerator::set_pole_equator_radius,
"Set the pole and equator radius for the reference system",
py::arg("equator_radius"), py::arg("pole_radius"))
.def("set_icosahedron_orient", &PySttGenerator::set_icosahedron_orient,
"Set the orientation of the icosahedron top vertex",
py::arg("longitude"), py::arg("latitude"))
.def("set_reference_system", &PySttGenerator::set_reference_system,
"Set the reference system (WGS84, Earth, Moon, or custom)",
py::arg("ref_system"))
.def("set_progress_callback", &PySttGenerator::set_progress_callback,
"Set a progress callback function for Jupyter notebook compatibility\n"
"Callback function should accept (description, percentage) parameters",
py::arg("callback"))
.def("clear_progress_callback", &PySttGenerator::clear_progress_callback,
"Clear the progress callback function")
.def("run", &PySttGenerator::run,
"Run the STT generation with basic parameters",
py::arg("output_msh_file") = "")
.def("run_full", &PySttGenerator::run_full,
"Run the STT generation with full parameters",
py::arg("params"))
.def("get_info", &PySttGenerator::get_info,
"Get information about the current STT generator configuration");
// 便利函数
m.def("create_stt", &create_stt,
"Create STT with simplified interface",
py::arg("min_depth"), py::arg("max_depth"),
py::arg("reference_system") = "WGS84",
py::arg("output_file") = "output.msh");
// 模块级别的get_info函数
m.def("get_info", &module_get_info,
"Get module information including version, author, and contact details");
// 进度回调函数 - 使用py::function
m.def("create_simple_callback", [](const std::string& description) {
return py::cpp_function([description](const std::string& desc, double percentage) {
std::cout << description << " - " << desc << ": " << percentage << "%" << std::endl;
});
}, "Create a simple progress callback function");
m.def("create_tqdm_callback", [](const std::string& description) {
if (!has_tqdm()) {
throw std::runtime_error("tqdm is not available");
}
// 返回一个简单的回调函数因为tqdm需要更复杂的设置
return py::cpp_function([description](const std::string& desc, double percentage) {
std::cout << description << " - " << desc << ": " << percentage << "%" << std::endl;
});
}, "Create a tqdm-based progress callback function");
m.def("create_progress_callback", [](const std::string& description) {
// 自动选择最适合的进度回调 - 总是返回简单版本
return py::cpp_function([description](const std::string& desc, double percentage) {
std::cout << description << " - " << desc << ": " << percentage << "%" << std::endl;
});
}, "Create the best available progress callback function");
// 参考系统常量
m.attr("WGS84") = "WGS84";
m.attr("EARTH") = "Earth";
m.attr("MOON") = "Moon";
// 检查tqdm是否可用
m.attr("HAS_TQDM") = has_tqdm();
// 版本信息
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
m.attr("__version__") = "dev";
#endif
}