Files
stt/pybind/python_progress_wrapper.h
2025-11-27 15:06:01 +08:00

95 lines
2.6 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef PYTHON_PROGRESS_WRAPPER_H
#define PYTHON_PROGRESS_WRAPPER_H
#ifdef PYTHON_BINDING
#include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include <memory>
#include <string>
#include <utility>
namespace py = pybind11;
// Abstract progress-reporting interface used by the Python bindings.
//
// Implementations receive progress either as a ready-made
// (description, percentage) pair through update(), or incrementally through
// set_description(), set_total(), update_progress() and finish().
class PythonProgressCallback {
public:
    // Virtual so implementations can be destroyed through a base pointer,
    // e.g. when owned by a std::unique_ptr<PythonProgressCallback>.
    virtual ~PythonProgressCallback() = default;

    // Report progress directly with a label and a percentage value.
    virtual void update(const std::string& text, double percent) = 0;

    // Set the label to be used by later incremental progress reports.
    virtual void set_description(const std::string& text) = 0;

    // Set the number of work units that corresponds to completion.
    virtual void set_total(unsigned long units) = 0;

    // Report the number of work units completed so far.
    virtual void update_progress(unsigned long done) = 0;

    // Signal that the tracked work has finished.
    virtual void finish() = 0;
};
// pybind11包装器
class PyProgressCallback : public PythonProgressCallback {
private:
py::object callback_func_;
std::string current_description_;
unsigned long total_;
bool has_total_;
public:
PyProgressCallback(py::object callback)
: callback_func_(callback), total_(0), has_total_(false) {}
void update(const std::string& description, double percentage) override {
if (callback_func_ && !callback_func_.is_none()) {
try {
callback_func_(description, percentage);
} catch (const std::exception& e) {
// 忽略Python回调中的异常避免崩溃
}
}
}
void set_description(const std::string& description) override {
current_description_ = description;
}
void set_total(unsigned long total) override {
total_ = total;
has_total_ = true;
}
void update_progress(unsigned long current) override {
if (has_total_ && total_ > 0) {
double percentage = (static_cast<double>(current) / total_) * 100.0;
update(current_description_, percentage);
}
}
void finish() override {
update(current_description_, 100.0);
}
};
// 全局进度回调管理器
class ProgressCallbackManager {
private:
static std::unique_ptr<PythonProgressCallback> global_callback_;
public:
static void set_callback(py::object callback) {
if (callback.is_none()) {
global_callback_.reset();
} else {
global_callback_.reset(new PyProgressCallback(callback));
}
}
static PythonProgressCallback* get_callback() {
return global_callback_.get();
}
static bool has_callback() {
return global_callback_ != nullptr;
}
};
// 静态成员定义
std::unique_ptr<PythonProgressCallback> ProgressCallbackManager::global_callback_ = nullptr;
#endif // PYTHON_BINDING
#endif // PYTHON_PROGRESS_WRAPPER_H