95 lines
2.6 KiB
C
95 lines
2.6 KiB
C
|
|
#ifndef PYTHON_PROGRESS_WRAPPER_H
|
|||
|
|
#define PYTHON_PROGRESS_WRAPPER_H
|
|||
|
|
|
|||
|
|
#ifdef PYTHON_BINDING
|
|||
|
|
|
|||
|
|
#include <pybind11/pybind11.h>
#include <pybind11/functional.h>

#include <memory>
#include <string>
#include <utility>

namespace py = pybind11;
|
|||
|
|
|
|||
|
|
/// Abstract progress-reporting interface intended for forwarding to Python.
///
/// Progress can be reported either directly through update(), or
/// incrementally through set_description() / set_total() / update_progress(),
/// followed by finish(). All semantics beyond these signatures are defined by
/// the concrete implementation.
class PythonProgressCallback {
public:
    /// Virtual so that implementations can be destroyed through a base pointer
    /// (e.g. when owned by std::unique_ptr<PythonProgressCallback>).
    virtual ~PythonProgressCallback() = default;

    /// Report progress directly.
    /// @param description Text describing the current step.
    /// @param percentage  Progress value; its scale is implementation-defined.
    virtual void update(const std::string& description, double percentage) = 0;

    /// Provide a description for subsequent incremental reports.
    /// @param description Text describing the current step.
    virtual void set_description(const std::string& description) = 0;

    /// Provide the total number of work units.
    /// @param total Total work units.
    virtual void set_total(unsigned long total) = 0;

    /// Report the number of completed work units.
    /// @param current Completed work units.
    virtual void update_progress(unsigned long current) = 0;

    /// Signal that the tracked work has completed.
    virtual void finish() = 0;
};
|
|||
|
|
|
|||
|
|
// pybind11包装器
|
|||
|
|
class PyProgressCallback : public PythonProgressCallback {
|
|||
|
|
private:
|
|||
|
|
py::object callback_func_;
|
|||
|
|
std::string current_description_;
|
|||
|
|
unsigned long total_;
|
|||
|
|
bool has_total_;
|
|||
|
|
|
|||
|
|
public:
|
|||
|
|
PyProgressCallback(py::object callback)
|
|||
|
|
: callback_func_(callback), total_(0), has_total_(false) {}
|
|||
|
|
|
|||
|
|
void update(const std::string& description, double percentage) override {
|
|||
|
|
if (callback_func_ && !callback_func_.is_none()) {
|
|||
|
|
try {
|
|||
|
|
callback_func_(description, percentage);
|
|||
|
|
} catch (const std::exception& e) {
|
|||
|
|
// 忽略Python回调中的异常,避免崩溃
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
void set_description(const std::string& description) override {
|
|||
|
|
current_description_ = description;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
void set_total(unsigned long total) override {
|
|||
|
|
total_ = total;
|
|||
|
|
has_total_ = true;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
void update_progress(unsigned long current) override {
|
|||
|
|
if (has_total_ && total_ > 0) {
|
|||
|
|
double percentage = (static_cast<double>(current) / total_) * 100.0;
|
|||
|
|
update(current_description_, percentage);
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
void finish() override {
|
|||
|
|
update(current_description_, 100.0);
|
|||
|
|
}
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
// 全局进度回调管理器
|
|||
|
|
class ProgressCallbackManager {
|
|||
|
|
private:
|
|||
|
|
static std::unique_ptr<PythonProgressCallback> global_callback_;
|
|||
|
|
|
|||
|
|
public:
|
|||
|
|
static void set_callback(py::object callback) {
|
|||
|
|
if (callback.is_none()) {
|
|||
|
|
global_callback_.reset();
|
|||
|
|
} else {
|
|||
|
|
global_callback_.reset(new PyProgressCallback(callback));
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
static PythonProgressCallback* get_callback() {
|
|||
|
|
return global_callback_.get();
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
static bool has_callback() {
|
|||
|
|
return global_callback_ != nullptr;
|
|||
|
|
}
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
// 静态成员定义
|
|||
|
|
std::unique_ptr<PythonProgressCallback> ProgressCallbackManager::global_callback_ = nullptr;
|
|||
|
|
|
|||
|
|
#endif // PYTHON_BINDING
|
|||
|
|
|
|||
|
|
#endif // PYTHON_PROGRESS_WRAPPER_H
|