Files
stt/pybind/progress_bar_python.h
2025-11-27 15:06:01 +08:00

74 lines
2.2 KiB
C++

#ifndef PROGRESS_BAR_PYTHON_H
#define PROGRESS_BAR_PYTHON_H
#ifdef PYTHON_BINDING
#include "python_progress_wrapper.h"
#include "../src/progress_bar.h"
/// Python-friendly progress bar.
///
/// Derives from ProgressBar. If ProgressCallbackManager reports a callback at
/// construction time, progress updates are forwarded to that callback;
/// otherwise (or if the callback pointer turns out to be null) every call
/// falls back to the base ProgressBar implementation.
class PythonProgressBar : public ProgressBar {
private:
    std::string description_;    // Copy of the description passed at construction.
    unsigned long total_count_;  // Total passed at construction (0 when default-constructed).
    bool use_python_callback_;   // True when updates should go to the Python callback.

    /// Maps a null C string to "" so it can safely be turned into a std::string.
    static const char* OrEmpty(const char* text) {
        return text ? text : "";
    }

public:
    /// Creates a bar with a total of 0 and the Python callback disabled.
    PythonProgressBar() : ProgressBar(), total_count_(0), use_python_callback_(false) {}

    /// Creates a bar for `n_` steps.
    ///
    /// @param n_            Total count; also used to detect the final step in Progressed().
    /// @param description_  Description text. A null pointer is treated as "" for the
    ///                      copy stored here and for the callback; the base class
    ///                      receives the pointer unchanged.
    /// @param out_          Stream handed to the base ProgressBar.
    PythonProgressBar(unsigned long n_, const char* description_ = "", std::ostream& out_ = std::cerr)
        : ProgressBar(n_, description_, out_),
          // Inside the parentheses `description_` names the parameter, not the
          // member. Constructing std::string from a null pointer is undefined
          // behaviour, so normalise it first.
          description_(OrEmpty(description_)),
          total_count_(n_),
          use_python_callback_(false) {
        // Switch to callback mode only if a Python callback is present.
        if (ProgressCallbackManager::has_callback()) {
            use_python_callback_ = true;
            if (auto* callback = ProgressCallbackManager::get_callback()) {
                callback->set_description(OrEmpty(description_));
                callback->set_total(n_);
            }
        }
    }

    /// Forwards to ProgressBar::SetFrequencyUpdate unless callback mode is on.
    /// In callback mode the request is ignored (the Python side uses a fixed
    /// update frequency).
    void SetFrequencyUpdate(unsigned long frequency_update_) {
        if (use_python_callback_) {
            return;
        }
        ProgressBar::SetFrequencyUpdate(frequency_update_);
    }

    /// Forwards to ProgressBar::SetStyle unless callback mode is on.
    /// In callback mode style settings are ignored.
    void SetStyle(const char* unit_bar_, const char* unit_space_) {
        if (use_python_callback_) {
            return;
        }
        ProgressBar::SetStyle(unit_bar_, unit_space_);
    }

    /// Reports that step `idx_` has been reached.
    ///
    /// In callback mode with a non-null callback, calls update_progress(idx_)
    /// and, once `idx_` reaches the last index (total - 1), finish(). In every
    /// other case the call falls back to ProgressBar::Progressed.
    void Progressed(unsigned long idx_) {
        if (use_python_callback_) {
            if (auto* callback = ProgressCallbackManager::get_callback()) {
                callback->update_progress(idx_);
                // total_count_ is unsigned: `total_count_ - 1` would wrap to the
                // maximum value when the total is 0, so guard the subtraction
                // explicitly instead of relying on wraparound.
                const bool reached_last_step =
                    total_count_ != 0 && idx_ >= total_count_ - 1;
                if (reached_last_step) {
                    callback->finish();
                }
                return;
            }
        }
        // Fall back to the original progress bar implementation.
        ProgressBar::Progressed(idx_);
    }

    /// Enables or disables forwarding to the Python callback.
    void set_use_python_callback(bool use) {
        use_python_callback_ = use;
    }
};
// Replace the original ProgressBar definition: from this point on, the
// preprocessor rewrites every `ProgressBar` token to `PythonProgressBar`, so
// code that includes this header (with PYTHON_BINDING defined) and names
// ProgressBar afterwards gets the Python-aware class instead.
// NOTE(review): this only affects code preprocessed after this line; anything
// that used ProgressBar earlier in the translation unit still refers to the
// original class.
#define ProgressBar PythonProgressBar
#endif // PYTHON_BINDING
#endif // PROGRESS_BAR_PYTHON_H