// Progress-callback wrapper: declares a progress-reporting interface, a
// pybind11-based implementation that forwards to a Python callable, and a
// manager holding one global callback. Everything after the include guard is
// compiled only when PYTHON_BINDING is defined.
#ifndef PYTHON_PROGRESS_WRAPPER_H
#define PYTHON_PROGRESS_WRAPPER_H

#ifdef PYTHON_BINDING

#include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include <string>
#include <memory>

namespace py = pybind11;

/// Progress callback interface for Python-facing progress reporting.
///
/// Pure abstract base: every operation is a pure virtual function, and the
/// virtual destructor allows implementations to be destroyed through a
/// pointer to this interface.
class PythonProgressCallback {
public:
    virtual ~PythonProgressCallback() = default;

    /// Receives a description string together with a percentage value.
    /// NOTE(review): the interface itself does not constrain the range of
    /// `percentage`; presumably 0-100 — confirm against implementations.
    virtual void update(const std::string& description, double percentage) = 0;

    /// Receives the description text to associate with the progress.
    virtual void set_description(const std::string& description) = 0;

    /// Receives the total amount of work (units are implementation-defined).
    virtual void set_total(unsigned long total) = 0;

    /// Receives the current progress position.
    /// NOTE(review): presumably expressed in the same units as set_total().
    virtual void update_progress(unsigned long current) = 0;

    /// Signals that the tracked operation has finished.
    virtual void finish() = 0;
};

// pybind11包装器
|
||
class PyProgressCallback : public PythonProgressCallback {
|
||
private:
|
||
py::object callback_func_;
|
||
std::string current_description_;
|
||
unsigned long total_;
|
||
bool has_total_;
|
||
|
||
public:
|
||
PyProgressCallback(py::object callback)
|
||
: callback_func_(callback), total_(0), has_total_(false) {}
|
||
|
||
void update(const std::string& description, double percentage) override {
|
||
if (callback_func_ && !callback_func_.is_none()) {
|
||
try {
|
||
callback_func_(description, percentage);
|
||
} catch (const std::exception& e) {
|
||
// 忽略Python回调中的异常,避免崩溃
|
||
}
|
||
}
|
||
}
|
||
|
||
void set_description(const std::string& description) override {
|
||
current_description_ = description;
|
||
}
|
||
|
||
void set_total(unsigned long total) override {
|
||
total_ = total;
|
||
has_total_ = true;
|
||
}
|
||
|
||
void update_progress(unsigned long current) override {
|
||
if (has_total_ && total_ > 0) {
|
||
double percentage = (static_cast<double>(current) / total_) * 100.0;
|
||
update(current_description_, percentage);
|
||
}
|
||
}
|
||
|
||
void finish() override {
|
||
update(current_description_, 100.0);
|
||
}
|
||
};
|
||
|
||
// 全局进度回调管理器
|
||
class ProgressCallbackManager {
|
||
private:
|
||
static std::unique_ptr<PythonProgressCallback> global_callback_;
|
||
|
||
public:
|
||
static void set_callback(py::object callback) {
|
||
if (callback.is_none()) {
|
||
global_callback_.reset();
|
||
} else {
|
||
global_callback_.reset(new PyProgressCallback(callback));
|
||
}
|
||
}
|
||
|
||
static PythonProgressCallback* get_callback() {
|
||
return global_callback_.get();
|
||
}
|
||
|
||
static bool has_callback() {
|
||
return global_callback_ != nullptr;
|
||
}
|
||
};
|
||
|
||
// 静态成员定义
|
||
std::unique_ptr<PythonProgressCallback> ProgressCallbackManager::global_callback_ = nullptr;
|
||
|
||
#endif // PYTHON_BINDING

#endif // PYTHON_PROGRESS_WRAPPER_H