tmp
This commit is contained in:
95
pybind/python_progress_wrapper.h
Normal file
95
pybind/python_progress_wrapper.h
Normal file
@@ -0,0 +1,95 @@
|
||||
#ifndef PYTHON_PROGRESS_WRAPPER_H
|
||||
#define PYTHON_PROGRESS_WRAPPER_H
|
||||
|
||||
#ifdef PYTHON_BINDING
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
// Python进度回调接口
|
||||
class PythonProgressCallback {
|
||||
public:
|
||||
virtual ~PythonProgressCallback() = default;
|
||||
virtual void update(const std::string& description, double percentage) = 0;
|
||||
virtual void set_description(const std::string& description) = 0;
|
||||
virtual void set_total(unsigned long total) = 0;
|
||||
virtual void update_progress(unsigned long current) = 0;
|
||||
virtual void finish() = 0;
|
||||
};
|
||||
|
||||
// pybind11包装器
|
||||
class PyProgressCallback : public PythonProgressCallback {
|
||||
private:
|
||||
py::object callback_func_;
|
||||
std::string current_description_;
|
||||
unsigned long total_;
|
||||
bool has_total_;
|
||||
|
||||
public:
|
||||
PyProgressCallback(py::object callback)
|
||||
: callback_func_(callback), total_(0), has_total_(false) {}
|
||||
|
||||
void update(const std::string& description, double percentage) override {
|
||||
if (callback_func_ && !callback_func_.is_none()) {
|
||||
try {
|
||||
callback_func_(description, percentage);
|
||||
} catch (const std::exception& e) {
|
||||
// 忽略Python回调中的异常,避免崩溃
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void set_description(const std::string& description) override {
|
||||
current_description_ = description;
|
||||
}
|
||||
|
||||
void set_total(unsigned long total) override {
|
||||
total_ = total;
|
||||
has_total_ = true;
|
||||
}
|
||||
|
||||
void update_progress(unsigned long current) override {
|
||||
if (has_total_ && total_ > 0) {
|
||||
double percentage = (static_cast<double>(current) / total_) * 100.0;
|
||||
update(current_description_, percentage);
|
||||
}
|
||||
}
|
||||
|
||||
void finish() override {
|
||||
update(current_description_, 100.0);
|
||||
}
|
||||
};
|
||||
|
||||
// 全局进度回调管理器
|
||||
class ProgressCallbackManager {
|
||||
private:
|
||||
static std::unique_ptr<PythonProgressCallback> global_callback_;
|
||||
|
||||
public:
|
||||
static void set_callback(py::object callback) {
|
||||
if (callback.is_none()) {
|
||||
global_callback_.reset();
|
||||
} else {
|
||||
global_callback_.reset(new PyProgressCallback(callback));
|
||||
}
|
||||
}
|
||||
|
||||
static PythonProgressCallback* get_callback() {
|
||||
return global_callback_.get();
|
||||
}
|
||||
|
||||
static bool has_callback() {
|
||||
return global_callback_ != nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// 静态成员定义
|
||||
std::unique_ptr<PythonProgressCallback> ProgressCallbackManager::global_callback_ = nullptr;
|
||||
|
||||
#endif // PYTHON_BINDING
|
||||
|
||||
#endif // PYTHON_PROGRESS_WRAPPER_H
|
||||
Reference in New Issue
Block a user