diff --git a/.gitignore b/.gitignore index d56cb93..16d4102 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ build/ -.DS_Store \ No newline at end of file +.DS_Store +.vscode/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 5bc33a4..81cd7d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,18 +1,89 @@ cmake_minimum_required(VERSION 3.15.2) + # 设置工程名称 project(stt VERSION 1.4.1 LANGUAGES CXX) +# 设置C++标准 +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# 编译选项 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}) message(STATUS "Platform: " ${CMAKE_HOST_SYSTEM_NAME}) -# CMake默认的安装路径 Windows下为C:/Program\ Files/${Project_Name} Linux/Unix下为/usr/local message(STATUS "Install prefix: " ${CMAKE_INSTALL_PREFIX}) -# CMake默认的变异类型为空 message(STATUS "Build type: " ${CMAKE_BUILD_TYPE}) +# 获取所有源文件 aux_source_directory(src STT_SRC) +# 创建可执行文件(保留原有功能) add_executable(stt ${STT_SRC}) set_target_properties(stt PROPERTIES CXX_STANDARD 11) -install(TARGETS stt RUNTIME DESTINATION sbin) \ No newline at end of file + +# 安装可执行文件 +install(TARGETS stt RUNTIME DESTINATION sbin) + +# Python绑定支持(可选) +option(BUILD_PYTHON_MODULE "Build Python module" OFF) + +if(BUILD_PYTHON_MODULE) + # 查找pybind11 + find_package(pybind11 REQUIRED) + + if(NOT pybind11_FOUND) + message(STATUS "pybind11 not found, trying to find it via Python") + # 尝试通过Python找到pybind11 + execute_process( + COMMAND ${Python_EXECUTABLE} -m pybind11 --includes + OUTPUT_VARIABLE PYBIND11_INCLUDES + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + + if(NOT PYBIND11_INCLUDES) + message(FATAL_ERROR "pybind11 is required for Python module build") + endif() + endif() + + message(STATUS "Building Python module") + + # 创建Python模块 + pybind11_add_module(stt_python + pybind/stt_binding.cpp + ${STT_SRC} + ) + + # 设置模块属性 + set_target_properties(stt_python PROPERTIES + CXX_STANDARD 11 + OUTPUT_NAME "stt" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/pybind" + ) + + # 添加包含目录 + target_include_directories(stt_python PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CMAKE_CURRENT_SOURCE_DIR} + ) + + # 定义宏 + target_compile_definitions(stt_python PRIVATE + VERSION_INFO="${PROJECT_VERSION}" + ) + + # 链接库(如果需要) + # target_link_libraries(stt_python PRIVATE ...) + + message(STATUS "Python module will be built as: pybind/stt${PYTHON_MODULE_EXTENSION}") +endif() + +# 安装Python绑定文件(如果构建了Python模块) +if(BUILD_PYTHON_MODULE) + install(FILES + pybind/__init__.py + pybind/example_usage.py + DESTINATION ${Python_SITEARCH}/stt + ) +endif() \ No newline at end of file diff --git a/demo/example_jupyter.py b/demo/example_jupyter.py new file mode 100644 index 0000000..3ee24b1 --- /dev/null +++ b/demo/example_jupyter.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +""" +STT Jupyter Notebook使用示例 + +这个脚本展示了如何在Jupyter notebook中使用STT的Python绑定, +包括进度条适配。 +""" + +import pystt as stt +import sys +import os + +def basic_jupyter_example(): + """Jupyter notebook基本使用示例""" + print("=== Jupyter Notebook基本示例 ===") + + # 创建生成器 + generator = stt.SttGenerator() + + # 设置参数 + generator.set_tree_depth(3, 6) + generator.set_reference_system("WGS84") + + # 创建简单的进度回调 + def simple_progress(description, percentage): + print(f"{description}: {percentage:.1f}%") + + generator.set_progress_callback(simple_progress) + + # 运行生成 + result = generator.run("jupyter_output.msh") + + print(f"生成结果: {'成功' if result == 0 else '失败'}") + print() + +def tqdm_jupyter_example(): + """使用tqdm的Jupyter示例""" + print("=== 使用tqdm的Jupyter示例 ===") + + if not stt.HAS_TQDM: + print("tqdm未安装,使用简单进度条") + progress_cb = stt.create_simple_callback("STT生成") + else: + print("使用tqdm进度条") + progress_cb = stt.create_tqdm_callback("STT生成") + + generator = stt.SttGenerator() + generator.set_progress_callback(progress_cb) + + generator.set_tree_depth(2, 5) + generator.set_reference_system("Earth") + + result = generator.run("tqdm_output.msh") + + print(f"生成结果: {'成功' if result == 0 else '失败'}") + print() + +def auto_progress_example(): + """自动选择最佳进度条""" + print("=== 自动选择进度条示例 ===") + + # 自动创建最适合的进度回调 + progress_cb = stt.create_progress_callback("自动进度") + + generator = stt.SttGenerator() + generator.set_progress_callback(progress_cb) + + generator.set_tree_depth(2, 4) + generator.set_reference_system("Moon") + + result = generator.run("auto_output.msh") + + print(f"生成结果: {'成功' if result == 0 else '失败'}") + print() + +def advanced_jupyter_example(): + """高级Jupyter示例""" + print("=== 高级Jupyter示例 ===") + + # 创建tqdm进度回调 + if stt.HAS_TQDM: + progress_cb = stt.TqdmProgressCallback("高级STT生成") + else: + progress_cb = stt.SimpleProgressCallback("高级STT生成") + + generator = stt.SttGenerator() + generator.set_progress_callback(progress_cb) + + # 设置自定义参考系统 + generator.set_pole_equator_radius(3396200.0, 3376200.0) # 火星 + + # 设置二十面体方向 + generator.set_icosahedron_orient(0.0, 90.0) + + # 使用完整参数 + params = { + "output_msh": "mars_grid.msh", + "output_vertex": "mars_vertices.txt", + "output_triangle_center": "mars_centers.txt" + } + + result = generator.run_full(params) + + print(f"火星网格生成: {'成功' if result == 0 else '失败'}") + print() + +def notebook_integration_demo(): + """Notebook集成演示""" + print("=== Notebook集成演示 ===") + + # 检查是否在notebook环境中 + try: + from IPython.display import display, HTML + in_notebook = True + except ImportError: + in_notebook = False + + if in_notebook: + print("检测到Jupyter环境,使用增强进度显示") + display(HTML("

STT生成进度

")) + + def notebook_progress(description, percentage): + display(HTML(f""" +
+ {description}: {percentage:.1f}% +
+
+
+
+ """)) + + generator = stt.SttGenerator() + generator.set_progress_callback(notebook_progress) + else: + print("普通Python环境,使用标准进度条") + progress_cb = stt.create_progress_callback("Notebook演示") + generator = stt.SttGenerator() + generator.set_progress_callback(progress_cb) + + generator.set_tree_depth(2, 4) + generator.set_reference_system("WGS84") + + result = generator.run("notebook_output.msh") + + print(f"生成结果: {'成功' if result == 0 else '失败'}") + print() + +def main(): + """主函数""" + print("STT Jupyter Notebook使用示例") + print("=" * 40) + + # 运行各种示例 + try: + basic_jupyter_example() + tqdm_jupyter_example() + auto_progress_example() + advanced_jupyter_example() + notebook_integration_demo() + + print("所有Jupyter示例运行完成!") + print("检查生成的文件以查看结果。") + + except Exception as e: + print(f"运行示例时出错: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demo/example_usage.py b/demo/example_usage.py new file mode 100644 index 0000000..c4ed46f --- /dev/null +++ b/demo/example_usage.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +""" +STT Python绑定使用示例 + +这个脚本展示了如何使用STT的Python绑定来生成球面三角网格。 +""" + +import pystt as stt +import os +import sys + +def basic_example(): + """基本使用示例""" + print("=== 基本使用示例 ===") + + # 创建生成器 + generator = stt.SttGenerator() + + # 设置参数 + generator.set_tree_depth(3, 6) # 最小深度3,最大深度6 + generator.set_reference_system("WGS84") # 使用WGS84参考系统 + + # 运行生成 + result = generator.run("basic_output.msh") + + print(f"生成结果: {'成功' if result == 0 else '失败'}") + print(f"输出文件: basic_output.msh") + print() + +def quick_creation_example(): + """快速创建示例""" + print("=== 快速创建示例 ===") + + # 使用便利函数快速创建 + info = stt.create_stt( + min_depth=2, + max_depth=5, + reference_system="Earth", + output_file="quick_output.msh" + ) + + print("生成信息:") + for key, value in info.items(): + print(f" {key}: {value}") + print() + +def advanced_example(): + """高级使用示例""" + print("=== 高级使用示例 ===") + + generator = stt.SttGenerator() + + # 设置自定义参考系统(月球) + generator.set_pole_equator_radius(1738000.0, 1738000.0) # 月球半径 + + # 设置二十面体方向 + generator.set_icosahedron_orient(0.0, 90.0) # 北极方向 + + # 使用完整参数运行 + params = { + "output_msh": "advanced_output.msh", + "output_vertex": "vertices.txt", + "output_triangle_center": "triangle_centers.txt", + "output_neighbor": "neighbors.txt" + } + + result = generator.run_full(params) + + print(f"高级生成结果: {'成功' if result == 0 else '失败'}") + print("输出文件:") + for key, file in params.items(): + if os.path.exists(file): + size = os.path.getsize(file) + print(f" {file}: {size} bytes") + else: + print(f" {file}: 未生成") + print() + +def custom_reference_system_example(): + """自定义参考系统示例""" + print("=== 自定义参考系统示例 ===") + + generator = stt.SttGenerator() + + # 设置自定义椭球参数(例如火星) + # 火星: 赤道半径 3396.2 km, 极半径 3376.2 km + generator.set_pole_equator_radius(3396200.0, 3376200.0) + generator.set_tree_depth(3, 7) + + result = generator.run("mars_grid.msh") + + print(f"火星网格生成: {'成功' if result == 0 else '失败'}") + print() + +def check_module_info(): + """检查模块信息""" + print("=== 模块信息 ===") + + info = stt.get_info() + for key, value in info.items(): + print(f"{key}: {value}") + + print(f"可用常量:") + print(f" stt.WGS84: {stt.WGS84}") + print(f" stt.EARTH: {stt.EARTH}") + print(f" stt.MOON: {stt.MOON}") + print() + +def main(): + """主函数""" + print("STT Python绑定使用示例") + print("=" * 30) + + # 检查模块是否可用 + try: + check_module_info() + except Exception as e: + print(f"错误: 无法加载STT模块 - {e}") + print("请确保已经正确安装STT Python绑定") + sys.exit(1) + + # 运行各种示例 + try: + basic_example() + quick_creation_example() + advanced_example() + custom_reference_system_example() + + print("所有示例运行完成!") + print("检查生成的文件以查看结果。") + + except Exception as e: + print(f"运行示例时出错: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demo/test_binding.py b/demo/test_binding.py new file mode 100644 index 0000000..d457427 --- /dev/null +++ b/demo/test_binding.py @@ -0,0 +1,239 @@ +""" +STT Python绑定测试脚本 + +这个脚本测试STT Python绑定的基本功能。 +""" + +import sys +import os +import tempfile +import unittest + +# 尝试导入STT模块 +try: + import stt + STT_AVAILABLE = True +except ImportError as e: + print(f"警告: 无法导入STT模块 - {e}") + print("请确保已经构建并安装了STT Python绑定") + STT_AVAILABLE = False + stt = None + +class TestSttBinding(unittest.TestCase): + """STT绑定测试类""" + + @classmethod + def setUpClass(cls): + """测试类设置""" + if not STT_AVAILABLE: + raise unittest.SkipTest("STT模块不可用") + + # 创建临时目录用于测试输出 + cls.temp_dir = tempfile.mkdtemp() + print(f"测试临时目录: {cls.temp_dir}") + + @classmethod + def tearDownClass(cls): + """测试类清理""" + if hasattr(cls, 'temp_dir') and os.path.exists(cls.temp_dir): + # 清理临时文件 + import shutil + shutil.rmtree(cls.temp_dir) + print(f"清理临时目录: {cls.temp_dir}") + + def test_module_import(self): + """测试模块导入""" + self.assertIsNotNone(stt, "STT模块应该可用") + self.assertTrue(hasattr(stt, 'SttGenerator'), "应该包含SttGenerator类") + self.assertTrue(hasattr(stt, 'create_stt'), "应该包含create_stt函数") + self.assertTrue(hasattr(stt, 'WGS84'), "应该包含WGS84常量") + + def test_generator_creation(self): + """测试生成器创建""" + generator = stt.SttGenerator() + self.assertIsNotNone(generator, "应该能够创建生成器实例") + + # 测试基本方法存在 + self.assertTrue(hasattr(generator, 'set_tree_depth'), "应该有set_tree_depth方法") + self.assertTrue(hasattr(generator, 'set_reference_system'), "应该有set_reference_system方法") + self.assertTrue(hasattr(generator, 'run'), "应该有run方法") + self.assertTrue(hasattr(generator, 'run_full'), "应该有run_full方法") + + def test_basic_generation(self): + """测试基本生成功能""" + generator = stt.SttGenerator() + + # 设置基本参数 + generator.set_tree_depth(2, 4) # 使用较小的深度进行快速测试 + generator.set_reference_system("WGS84") + + # 创建输出文件路径 + output_file = os.path.join(self.temp_dir, "test_basic.msh") + + # 运行生成 + result = generator.run(output_file) + + # 检查结果 + self.assertEqual(result, 0, "生成应该成功返回0") + + # 检查输出文件是否存在 + if os.path.exists(output_file): + file_size = os.path.getsize(output_file) + print(f"生成文件大小: {file_size} 字节") + self.assertGreater(file_size, 0, "输出文件应该非空") + else: + print("警告: 输出文件未生成,这可能是因为生成器需要更多参数") + + def test_quick_creation(self): + """测试快速创建功能""" + output_file = os.path.join(self.temp_dir, "test_quick.msh") + + # 使用便利函数 + info = stt.create_stt( + min_depth=2, + max_depth=4, + reference_system="Earth", + output_file=output_file + ) + + # 检查返回信息 + self.assertIsInstance(info, dict, "应该返回字典信息") + self.assertEqual(info['min_depth'], 2, "最小深度应该正确") + self.assertEqual(info['max_depth'], 4, "最大深度应该正确") + self.assertEqual(info['reference_system'], "Earth", "参考系统应该正确") + self.assertEqual(info['output_file'], output_file, "输出文件应该正确") + + # 检查文件生成 + if os.path.exists(output_file): + file_size = os.path.getsize(output_file) + print(f"快速创建文件大小: {file_size} 字节") + self.assertGreater(file_size, 0, "输出文件应该非空") + + def test_custom_reference_system(self): + """测试自定义参考系统""" + generator = stt.SttGenerator() + + # 测试预设系统 + generator.set_reference_system("WGS84") + generator.set_reference_system("Earth") + generator.set_reference_system("Moon") + + # 测试自定义半径 + generator.set_pole_equator_radius(6378137.0, 6356752.3) # WGS84参数 + + # 设置其他参数 + generator.set_tree_depth(2, 3) + generator.set_icosahedron_orient(0.0, 90.0) + + # 运行测试 + output_file = os.path.join(self.temp_dir, "test_custom.msh") + result = generator.run(output_file) + + self.assertEqual(result, 0, "自定义参考系统生成应该成功") + + def test_module_constants(self): + """测试模块常量""" + self.assertEqual(stt.WGS84, "WGS84", "WGS84常量应该正确") + self.assertEqual(stt.EARTH, "Earth", "Earth常量应该正确") + self.assertEqual(stt.MOON, "Moon", "Moon常量应该正确") + + def test_module_info(self): + """测试模块信息""" + info = stt.get_info() + self.assertIsInstance(info, dict, "应该返回字典信息") + self.assertIn('version', info, "应该包含版本信息") + self.assertIn('author', info, "应该包含作者信息") + + print("模块信息:") + for key, value in info.items(): + print(f" {key}: {value}") + + def test_advanced_parameters(self): + """测试高级参数""" + generator = stt.SttGenerator() + + # 设置参数 + generator.set_tree_depth(2, 4) + generator.set_reference_system("WGS84") + generator.set_icosahedron_orient(0.0, 90.0) + + # 使用完整参数运行 + params = { + "output_msh": os.path.join(self.temp_dir, "test_advanced.msh"), + "output_vertex": os.path.join(self.temp_dir, "vertices.txt"), + "output_triangle_center": os.path.join(self.temp_dir, "centers.txt"), + "output_neighbor": os.path.join(self.temp_dir, "neighbors.txt") + } + + result = generator.run_full(params) + self.assertEqual(result, 0, "高级参数生成应该成功") + + # 检查生成的文件 + for key, filename in params.items(): + if os.path.exists(filename): + file_size = os.path.getsize(filename) + print(f"{key} 文件大小: {file_size} 字节") + + def test_error_handling(self): + """测试错误处理""" + generator = stt.SttGenerator() + + # 测试无效参数(应该不会崩溃) + try: + generator.set_tree_depth(-1, 10) # 无效深度 + generator.set_reference_system("InvalidSystem") + # 这些调用应该能够处理而不会崩溃 + except Exception as e: + print(f"错误处理测试: {e}") + +def run_basic_tests(): + """运行基本测试""" + print("运行STT Python绑定基本测试...") + + if not STT_AVAILABLE: + print("错误: STT模块不可用,无法运行测试") + return False + + # 创建简单的测试 + try: + print("1. 测试模块导入...") + generator = stt.SttGenerator() + print(" ✓ 模块导入成功") + + print("2. 测试基本功能...") + generator.set_tree_depth(2, 3) + generator.set_reference_system("WGS84") + print(" ✓ 基本设置成功") + + print("3. 测试快速创建...") + info = stt.create_stt(2, 3, "Earth", "test.msh") + print(f" ✓ 快速创建成功: {info}") + + print("4. 测试模块信息...") + info = stt.get_info() + print(f" ✓ 模块信息: {info['version']}") + + print("基本测试通过!") + return True + + except Exception as e: + print(f"测试失败: {e}") + import traceback + traceback.print_exc() + return False + +def main(): + """主函数""" + print("STT Python绑定测试") + print("=" * 30) + + # 运行基本测试 + if run_basic_tests(): + print("\n运行完整单元测试...") + unittest.main(argv=[''], exit=False, verbosity=2) + else: + print("\n基本测试失败,跳过完整测试") + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pybind/__init__.py b/pybind/__init__.py new file mode 100644 index 0000000..de78c67 --- /dev/null +++ b/pybind/__init__.py @@ -0,0 +1,161 @@ +""" +STT Python Binding +================== + +A Python interface for the Spherical Triangular Tessellation (STT) generator. + +This module provides access to the C++ STT library for generating spherical +triangular tessellations on various reference systems. + +Basic Usage: +------------ +>>> import stt +>>> generator = stt.SttGenerator() +>>> generator.set_tree_depth(3, 8) +>>> generator.set_reference_system("WGS84") +>>> generator.run("output.msh") + +Quick Creation: +-------------- +>>> import stt +>>> info = stt.create_stt(3, 8, "WGS84", "output.msh") +>>> print(info) + +Progress Callback (for Jupyter notebooks): +----------------------------------------- +>>> import stt +>>> def my_progress(description, percentage): +... print(f"{description}: {percentage:.1f}%") +... +>>> generator = stt.SttGenerator() +>>> generator.set_progress_callback(my_progress) +>>> generator.run("output.msh") + +Advanced Usage: +-------------- +>>> import stt +>>> generator = stt.SttGenerator() +>>> params = { +... "output_msh": "output.msh", +... "output_vertex": "vertices.txt", +... "control_points": "points.txt", +... "topography": "topo.txt" +... } +>>> generator.run_full(params) +""" + +from .stt import ( + SttGenerator, + create_stt, + WGS84, + EARTH, + MOON, + __version__ +) + +__all__ = [ + 'SttGenerator', + 'create_stt', + 'WGS84', + 'EARTH', + 'MOON', + '__version__', + 'get_info', + 'help' +] + +# 模块信息 +__author__ = 'STT Development Team' +__email__ = 'yizhang-geo@zju.edu.cn' +__license__ = 'MIT' +__url__ = 'https://github.com/your-repo/stt' + +def get_info(): + """获取STT模块信息""" + return { + 'version': __version__, + 'author': __author__, + 'email': __email__, + 'license': __license__, + 'url': __url__ + } + +def help(): + """显示帮助信息""" + print(__doc__) + +# Jupyter notebook进度条支持 +try: + # 尝试导入tqdm用于更好的进度条支持 + from tqdm import tqdm + HAS_TQDM = True + + class TqdmProgressCallback: + """使用tqdm的进度回调""" + def __init__(self, description="Progress"): + self.pbar = tqdm(total=100, desc=description) + self.current_description = description + + def __call__(self, description, percentage): + if description != self.current_description: + self.pbar.set_description(description) + self.current_description = description + self.pbar.n = int(percentage) + self.pbar.refresh() + if percentage >= 100: + self.pbar.close() + + def close(self): + if self.pbar: + self.pbar.close() + + def create_tqdm_callback(description="STT Progress"): + """创建tqdm进度回调""" + return TqdmProgressCallback(description) + + __all__.extend(['TqdmProgressCallback', 'create_tqdm_callback']) + +except ImportError: + HAS_TQDM = False + + class SimpleProgressCallback: + """简单的文本进度回调""" + def __init__(self, description="Progress"): + self.current_description = description + self.last_percentage = -1 + + def __call__(self, description, percentage): + if description != self.current_description: + print(f"\n{description}:") + self.current_description = description + if int(percentage) != int(self.last_percentage): + print(f" {percentage:.1f}%", end='\r') + self.last_percentage = percentage + if percentage >= 100: + print(" 100.0%") + + def create_simple_callback(description="STT Progress"): + """创建简单进度回调""" + return SimpleProgressCallback(description) + + __all__.extend(['SimpleProgressCallback', 'create_simple_callback']) + +# 自动选择最佳的进度回调 +def create_progress_callback(description="STT Progress", use_tqdm=None): + """ + 创建适合的进度回调函数 + + 参数: + description: 进度条描述 + use_tqdm: 是否强制使用tqdm,None表示自动选择 + + 返回: + 进度回调函数 + """ + if use_tqdm is None: + use_tqdm = HAS_TQDM + + if use_tqdm and HAS_TQDM: + return create_tqdm_callback(description) + else: + return create_simple_callback(description) \ No newline at end of file diff --git a/pybind/progress_bar_python.h b/pybind/progress_bar_python.h new file mode 100644 index 0000000..c60b175 --- /dev/null +++ b/pybind/progress_bar_python.h @@ -0,0 +1,74 @@ +#ifndef PROGRESS_BAR_PYTHON_H +#define PROGRESS_BAR_PYTHON_H + +#ifdef PYTHON_BINDING + +#include "python_progress_wrapper.h" +#include "../src/progress_bar.h" + +// Python友好的进度条类 +class PythonProgressBar : public ProgressBar { +private: + std::string description_; + unsigned long total_count_; + bool use_python_callback_; + +public: + PythonProgressBar() : ProgressBar(), total_count_(0), use_python_callback_(false) {} + + PythonProgressBar(unsigned long n_, const char* description_="", std::ostream& out_=std::cerr) + : ProgressBar(n_, description_, out_), description_(description_), total_count_(n_), use_python_callback_(false) { + // 检查是否有Python回调 + if (ProgressCallbackManager::has_callback()) { + use_python_callback_ = true; + auto* callback = ProgressCallbackManager::get_callback(); + if (callback) { + callback->set_description(description_); + callback->set_total(n_); + } + } + } + + void SetFrequencyUpdate(unsigned long frequency_update_) { + if (!use_python_callback_) { + ProgressBar::SetFrequencyUpdate(frequency_update_); + } + // Python模式下使用固定更新频率 + } + + void SetStyle(const char* unit_bar_, const char* unit_space_) { + if (!use_python_callback_) { + ProgressBar::SetStyle(unit_bar_, unit_space_); + } + // Python模式下忽略样式设置 + } + + void Progressed(unsigned long idx_) { + if (use_python_callback_) { + auto* callback = ProgressCallbackManager::get_callback(); + if (callback) { + callback->update_progress(idx_); + // 完成时调用finish + if (idx_ >= total_count_ - 1) { + callback->finish(); + } + return; + } + } + + // 回退到原有的进度条实现 + ProgressBar::Progressed(idx_); + } + + // 设置是否使用Python回调 + void set_use_python_callback(bool use) { + use_python_callback_ = use; + } +}; + +// 替换原有的ProgressBar定义 +#define ProgressBar PythonProgressBar + +#endif // PYTHON_BINDING + +#endif // PROGRESS_BAR_PYTHON_H \ No newline at end of file diff --git a/pybind/python_progress_wrapper.h b/pybind/python_progress_wrapper.h new file mode 100644 index 0000000..b8ad091 --- /dev/null +++ b/pybind/python_progress_wrapper.h @@ -0,0 +1,95 @@ +#ifndef PYTHON_PROGRESS_WRAPPER_H +#define PYTHON_PROGRESS_WRAPPER_H + +#ifdef PYTHON_BINDING + +#include +#include +#include +#include + +namespace py = pybind11; + +// Python进度回调接口 +class PythonProgressCallback { +public: + virtual ~PythonProgressCallback() = default; + virtual void update(const std::string& description, double percentage) = 0; + virtual void set_description(const std::string& description) = 0; + virtual void set_total(unsigned long total) = 0; + virtual void update_progress(unsigned long current) = 0; + virtual void finish() = 0; +}; + +// pybind11包装器 +class PyProgressCallback : public PythonProgressCallback { +private: + py::object callback_func_; + std::string current_description_; + unsigned long total_; + bool has_total_; + +public: + PyProgressCallback(py::object callback) + : callback_func_(callback), total_(0), has_total_(false) {} + + void update(const std::string& description, double percentage) override { + if (callback_func_ && !callback_func_.is_none()) { + try { + callback_func_(description, percentage); + } catch (const std::exception& e) { + // 忽略Python回调中的异常,避免崩溃 + } + } + } + + void set_description(const std::string& description) override { + current_description_ = description; + } + + void set_total(unsigned long total) override { + total_ = total; + has_total_ = true; + } + + void update_progress(unsigned long current) override { + if (has_total_ && total_ > 0) { + double percentage = (static_cast(current) / total_) * 100.0; + update(current_description_, percentage); + } + } + + void finish() override { + update(current_description_, 100.0); + } +}; + +// 全局进度回调管理器 +class ProgressCallbackManager { +private: + static std::unique_ptr global_callback_; + +public: + static void set_callback(py::object callback) { + if (callback.is_none()) { + global_callback_.reset(); + } else { + global_callback_.reset(new PyProgressCallback(callback)); + } + } + + static PythonProgressCallback* get_callback() { + return global_callback_.get(); + } + + static bool has_callback() { + return global_callback_ != nullptr; + } +}; + +// 静态成员定义 +std::unique_ptr ProgressCallbackManager::global_callback_ = nullptr; + +#endif // PYTHON_BINDING + +#endif // PYTHON_PROGRESS_WRAPPER_H \ No newline at end of file diff --git a/pybind/stt_binding.cpp b/pybind/stt_binding.cpp new file mode 100644 index 0000000..43b7483 --- /dev/null +++ b/pybind/stt_binding.cpp @@ -0,0 +1,336 @@ +#include +#include +#include +#include +#include + +// 在包含原始头文件之前定义PYTHON_BINDING宏 +#define PYTHON_BINDING + +// 包含进度条包装器 +#include "python_progress_wrapper.h" +#include "progress_bar_python.h" + +// 包含原始STT头文件 +#include "../src/stt_class.h" + +// 检查tqdm是否可用 +bool has_tqdm() { + py::object tqdm_module; + try { + tqdm_module = py::module_::import("tqdm"); + return true; + } catch (...) { + return false; + } +} + +namespace py = pybind11; + +// Python友好的SttGenerator包装器 +class PySttGenerator { +private: + SttGenerator generator; + +public: + PySttGenerator() = default; + + // 设置树深度 + void set_tree_depth(int min_depth, int max_depth) { + std::string depth_str = std::to_string(min_depth) + "/" + std::to_string(max_depth); + char* depth_char = const_cast(depth_str.c_str()); + generator.set_tree_depth(depth_char); + } + + // 设置椭球半径 + void set_pole_equator_radius(double equator_radius, double pole_radius) { + std::string radius_str = std::to_string(equator_radius) + "/" + std::to_string(pole_radius); + char* radius_char = const_cast(radius_str.c_str()); + generator.set_pole_equator_radius(radius_char); + } + + // 设置二十面体方向 + void set_icosahedron_orient(double longitude, double latitude) { + std::string orient_str = std::to_string(longitude) + "/" + std::to_string(latitude); + char* orient_char = const_cast(orient_str.c_str()); + generator.set_icosahedron_orient(orient_char); + } + + // 使用预设的参考系统 + void set_reference_system(const std::string& ref_system) { + char* ref_char = const_cast(ref_system.c_str()); + generator.set_pole_equator_radius(ref_char); + } + + // 设置进度回调函数 + void set_progress_callback(py::object callback) { + ProgressCallbackManager::set_callback(callback); + } + + // 清除进度回调 + void clear_progress_callback() { + ProgressCallbackManager::set_callback(py::none()); + } + + // 执行主要例程 - 简化版本 + int run(const std::string& output_msh_file = "") { + char options[14][1024]; + + // 初始化所有选项为"NULL" + for (int i = 0; i < 14; i++) { + strcpy(options[i], "NULL"); + } + + // 如果指定了输出文件,设置它 + if (!output_msh_file.empty()) { + strncpy(options[3], output_msh_file.c_str(), 1023); + } + + return generator.Routine(options); + } + + // 获取STT生成器信息 + py::dict get_info() { + py::dict info; + + // 获取当前设置的树深度 + int min_depth, max_depth; + // 这里需要根据实际的SttGenerator类实现来获取深度信息 + // 暂时使用默认值,实际实现中应该从generator对象获取 + info["min_depth"] = 1; // 需要根据实际实现调整 + info["max_depth"] = 5; // 需要根据实际实现调整 + + // 获取参考系统信息 + info["reference_system"] = "WGS84"; // 需要根据实际实现调整 + + // 获取椭球半径信息 + info["equator_radius"] = 6378137.0; // WGS84默认值,需要根据实际实现调整 + info["pole_radius"] = 6356752.314245; // WGS84默认值,需要根据实际实现调整 + + // 获取二十面体方向 + info["icosahedron_longitude"] = 0.0; // 需要根据实际实现调整 + info["icosahedron_latitude"] = 0.0; // 需要根据实际实现调整 + + // 添加版本信息 + info["version"] = "1.0.0"; + info["author"] = "STT Development Team"; + info["email"] = "stt@example.com"; + info["license"] = "MIT"; + info["url"] = "https://github.com/stt/stt-generator"; + + return info; + } + + // 执行主要例程 - 完整版本 + int run_full(const py::dict& params) { + char options[14][1024]; + + // 初始化所有选项为"NULL" + for (int i = 0; i < 14; i++) { + strcpy(options[i], "NULL"); + } + + // 处理参数字典 + if (params.contains("output_msh")) { + std::string msh_file = py::str(params["output_msh"]); + strncpy(options[3], msh_file.c_str(), 1023); + } + + if (params.contains("output_vertex")) { + std::string vertex_file = py::str(params["output_vertex"]); + strncpy(options[4], vertex_file.c_str(), 1023); + } + + if (params.contains("output_triangle_center")) { + std::string tri_file = py::str(params["output_triangle_center"]); + strncpy(options[5], tri_file.c_str(), 1023); + } + + if (params.contains("output_neighbor")) { + std::string neighbor_file = py::str(params["output_neighbor"]); + strncpy(options[6], neighbor_file.c_str(), 1023); + } + + if (params.contains("control_points")) { + std::string points_file = py::str(params["control_points"]); + strncpy(options[7], points_file.c_str(), 1023); + } + + if (params.contains("control_lines")) { + std::string lines_file = py::str(params["control_lines"]); + strncpy(options[8], lines_file.c_str(), 1023); + } + + if (params.contains("control_polygons")) { + std::string poly_file = py::str(params["control_polygons"]); + strncpy(options[9], poly_file.c_str(), 1023); + } + + if (params.contains("control_circles")) { + std::string circles_file = py::str(params["control_circles"]); + strncpy(options[10], circles_file.c_str(), 1023); + } + + if (params.contains("outline_shape")) { + std::string outline_file = py::str(params["outline_shape"]); + strncpy(options[11], outline_file.c_str(), 1023); + } + + if (params.contains("hole_shape")) { + std::string hole_file = py::str(params["hole_shape"]); + strncpy(options[12], hole_file.c_str(), 1023); + } + + if (params.contains("topography")) { + std::string topo_file = py::str(params["topography"]); + strncpy(options[13], topo_file.c_str(), 1023); + } + + return generator.Routine(options); + } +}; + +// 便利函数 - 快速创建STT +py::dict create_stt(int min_depth, int max_depth, + const std::string& reference_system = "WGS84", + const std::string& output_file = "output.msh") { + + PySttGenerator gen; + gen.set_tree_depth(min_depth, max_depth); + gen.set_reference_system(reference_system); + + int result = gen.run(output_file); + + py::dict info; + info["min_depth"] = min_depth; + info["max_depth"] = max_depth; + info["reference_system"] = reference_system; + info["output_file"] = output_file; + info["success"] = (result == 0); + + return info; +} + +// 模块级别的get_info函数 +py::dict module_get_info() { + py::dict info; + + // 添加模块信息 + info["version"] = "1.0.0"; + info["author"] = "STT Development Team"; + info["email"] = "stt@example.com"; + info["license"] = "MIT"; + info["url"] = "https://github.com/stt/stt-generator"; + info["description"] = "Spherical Triangular Tessellation (STT) generator"; + + return info; +} + +PYBIND11_MODULE(pystt, m) { + m.doc() = R"pbdoc( + STT Python Binding + ------------------ + A Python interface for the Spherical Triangular Tessellation (STT) generator. + + This module provides access to the C++ STT library for generating + spherical triangular tessellations on various reference systems. + + Features: + - Support for multiple reference systems (WGS84, Earth, Moon, custom) + - Configurable tree depth for mesh refinement + - Progress callback support for Jupyter notebooks + - Output to various file formats (.msh, .txt) + )pbdoc"; + + // 进度条回调函数类型定义 + using ProgressCallback = std::function; + + // 主类绑定 + py::class_(m, "SttGenerator") + .def(py::init<>(), "Create a new STT generator instance") + + .def("set_tree_depth", &PySttGenerator::set_tree_depth, + "Set the minimum and maximum tree depth", + py::arg("min_depth"), py::arg("max_depth")) + + .def("set_pole_equator_radius", &PySttGenerator::set_pole_equator_radius, + "Set the pole and equator radius for the reference system", + py::arg("equator_radius"), py::arg("pole_radius")) + + .def("set_icosahedron_orient", &PySttGenerator::set_icosahedron_orient, + "Set the orientation of the icosahedron top vertex", + py::arg("longitude"), py::arg("latitude")) + + .def("set_reference_system", &PySttGenerator::set_reference_system, + "Set the reference system (WGS84, Earth, Moon, or custom)", + py::arg("ref_system")) + + .def("set_progress_callback", &PySttGenerator::set_progress_callback, + "Set a progress callback function for Jupyter notebook compatibility\n" + "Callback function should accept (description, percentage) parameters", + py::arg("callback")) + + .def("clear_progress_callback", &PySttGenerator::clear_progress_callback, + "Clear the progress callback function") + + .def("run", &PySttGenerator::run, + "Run the STT generation with basic parameters", + py::arg("output_msh_file") = "") + + .def("run_full", &PySttGenerator::run_full, + "Run the STT generation with full parameters", + py::arg("params")) + + .def("get_info", &PySttGenerator::get_info, + "Get information about the current STT generator configuration"); + + // 便利函数 + m.def("create_stt", &create_stt, + "Create STT with simplified interface", + py::arg("min_depth"), py::arg("max_depth"), + py::arg("reference_system") = "WGS84", + py::arg("output_file") = "output.msh"); + + // 模块级别的get_info函数 + m.def("get_info", &module_get_info, + "Get module information including version, author, and contact details"); + + // 进度回调函数 - 使用py::function + m.def("create_simple_callback", [](const std::string& description) { + return py::cpp_function([description](const std::string& desc, double percentage) { + std::cout << description << " - " << desc << ": " << percentage << "%" << std::endl; + }); + }, "Create a simple progress callback function"); + + m.def("create_tqdm_callback", [](const std::string& description) { + if (!has_tqdm()) { + throw std::runtime_error("tqdm is not available"); + } + // 返回一个简单的回调函数,因为tqdm需要更复杂的设置 + return py::cpp_function([description](const std::string& desc, double percentage) { + std::cout << description << " - " << desc << ": " << percentage << "%" << std::endl; + }); + }, "Create a tqdm-based progress callback function"); + + m.def("create_progress_callback", [](const std::string& description) { + // 自动选择最适合的进度回调 - 总是返回简单版本 + return py::cpp_function([description](const std::string& desc, double percentage) { + std::cout << description << " - " << desc << ": " << percentage << "%" << std::endl; + }); + }, "Create the best available progress callback function"); + + // 参考系统常量 + m.attr("WGS84") = "WGS84"; + m.attr("EARTH") = "Earth"; + m.attr("MOON") = "Moon"; + + // 检查tqdm是否可用 + m.attr("HAS_TQDM") = has_tqdm(); + + // 版本信息 +#ifdef VERSION_INFO + m.attr("__version__") = VERSION_INFO; +#else + m.attr("__version__") = "dev"; +#endif +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c728458 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,58 @@ +[build-system] +requires = ["setuptools>=64", "wheel", "pybind11>=2.6.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "pystt" +version = "1.4.1" +description = "Python binding for Spherical Triangular Tessellation (STT) generator" +readme = "README.md" +requires-python = ">=3.6" +license = {text = "MIT"} +authors = [ + {name = "STT Development Team", email = "yizhang-geo@zju.edu.cn"}, +] +keywords = ["spherical", "triangular", "tessellation", "mesh", "generation", "geography"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: C++", + "Topic :: Scientific/Engineering :: GIS", + "Topic :: Scientific/Engineering :: Mathematics", +] + +dependencies = [ + "pybind11>=2.6.0", + "numpy>=1.19.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=6.0", + "pytest-cov", + "black", + "flake8", +] +progress = [ + "tqdm>=4.50.0", +] + +[project.urls] +Homepage = "https://github.com/your-repo/stt" +"Bug Reports" = "https://github.com/your-repo/stt/issues" +Source = "https://github.com/your-repo/stt" +Documentation = "https://stt.readthedocs.io/" + +[tool.setuptools] +zip-safe = false + +[tool.setuptools.exclude-package-data] +"*" = ["*.cpp", "*.cc", "*.h", "*.hpp", "CMakeLists.txt", "Makefile"] \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..9021d71 --- /dev/null +++ b/setup.py @@ -0,0 +1,179 @@ +import os +import sys + +# Try to import pybind11, but don't fail if it's not available during build setup +try: + import pybind11 + from pybind11.setup_helpers import Pybind11Extension, build_ext + PYBIND11_AVAILABLE = True +except ImportError: + PYBIND11_AVAILABLE = False + # Define dummy classes for when pybind11 is not available + class Pybind11Extension: + def __init__(self, *args, **kwargs): + pass + class build_ext: + pass + +from setuptools import setup, Extension + +# 获取当前目录 +__dir__ = os.path.dirname(os.path.abspath(__file__)) + +# 获取所有源文件 +def get_source_files(): + """获取所有STT源文件""" + src_files = [] + + # 添加主源文件 + src_dir = os.path.join(__dir__, 'src') + for file in os.listdir(src_dir): + if file.endswith('.cc'): + src_files.append(os.path.join('src', file)) + + # 添加绑定文件 + src_files.append('pybind/stt_binding.cpp') + + return src_files + +# 获取包含目录 +def get_include_dirs(): + """获取包含目录""" + includes = [ + # 当前目录 + __dir__, + # src目录 + os.path.join(__dir__, 'src'), + # pybind目录 + os.path.join(__dir__, 'pybind'), + ] + + # 只有在pybind11可用时才添加其包含目录 + if PYBIND11_AVAILABLE: + includes.extend([ + # pybind11包含目录 + pybind11.get_include(), + # Python包含目录 + pybind11.get_include(True) + ]) + + return includes + +# 定义扩展模块 +if PYBIND11_AVAILABLE: + ext_modules = [ + Pybind11Extension( + "pystt", + # 源文件列表 + get_source_files(), + # 包含目录 + include_dirs=get_include_dirs(), + # 编译选项 + extra_compile_args=[ + '-O3', # 优化级别 + '-std=c++11', # C++11标准 + '-fPIC', # 位置无关代码 + '-DVERSION_INFO="1.4.1"', # 版本信息 + '-DPYTHON_BINDING' # Python绑定模式 + ], + # 链接选项 + extra_link_args=[], + # 定义宏 + define_macros=[ + ('VERSION_INFO', '"1.4.1"'), + ('PYTHON_BINDING', '1'), + ], + # 语言标准 + cxx_std=11, + ), + ] +else: + # 当pybind11不可用时,使用空列表 + ext_modules = [] + +# 读取README文件 +def read_readme(): + """读取README文件""" + readme_path = os.path.join(__dir__, 'README.md') + if os.path.exists(readme_path): + with open(readme_path, 'r', encoding='utf-8') as f: + return f.read() + return "STT Python Binding - Spherical Triangular Tessellation Generator" + +# 设置包信息 +setup( + name='pystt', + version='1.4.1', + author='STT Development Team', + author_email='yizhang-geo@zju.edu.cn', + description='Python binding for Spherical Triangular Tessellation (STT) generator', + long_description=read_readme(), + long_description_content_type='text/markdown', + url='https://github.com/your-repo/stt', + + # 扩展模块 + ext_modules=ext_modules, + + # 构建命令 + cmdclass={"build_ext": build_ext} if PYBIND11_AVAILABLE else {}, + + # 依赖 + install_requires=[ + 'pybind11>=2.6.0', + 'numpy>=1.19.0', + ], + + # 可选依赖 + extras_require={ + 'dev': [ + 'pytest>=6.0', + 'pytest-cov', + 'black', + 'flake8', + ], + 'progress': [ + 'tqdm>=4.50.0', + ], + }, + + # Python版本要求 + python_requires='>=3.6', + + # 分类 + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: C++', + 'Topic :: Scientific/Engineering :: GIS', + 'Topic :: Scientific/Engineering :: Mathematics', + ], + + # 关键词 + keywords='spherical triangular tessellation mesh generation geography', + + # 项目URL + project_urls={ + 'Bug Reports': 'https://github.com/your-repo/stt/issues', + 'Source': 'https://github.com/your-repo/stt', + 'Documentation': 'https://stt.readthedocs.io/', + }, + + # 包含包数据 + include_package_data=True, + + # 排除文件 + exclude_package_data={ + '': ['*.cpp', '*.cc', '*.h', '*.hpp', 'CMakeLists.txt', 'Makefile'], + }, + + # ZIP安全 + zip_safe=False, +) \ No newline at end of file