mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
comment
This commit is contained in:
@@ -22,18 +22,15 @@ void Worker::add_task(std::function<void()> task) {
|
|||||||
pending_tasks_.push_back(std::move(task));
|
pending_tasks_.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
void signal_worker(void* data) {
|
void Worker::signal(void* data) {
|
||||||
auto w = static_cast<Worker*>(data);
|
auto w = static_cast<Worker*>(data);
|
||||||
w->signal_();
|
{
|
||||||
|
std::lock_guard lock(w->mtx_);
|
||||||
|
w->signaled_batch_++;
|
||||||
|
}
|
||||||
|
w->cond_.notify_one();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::signal_() {
|
|
||||||
{
|
|
||||||
std::lock_guard lock(mtx_);
|
|
||||||
signaled_batch_++;
|
|
||||||
}
|
|
||||||
cond_.notify_one();
|
|
||||||
}
|
|
||||||
|
|
||||||
void Worker::commit(cudaStream_t stream) {
|
void Worker::commit(cudaStream_t stream) {
|
||||||
// Move pending tasks into tasks
|
// Move pending tasks into tasks
|
||||||
@@ -47,7 +44,7 @@ void Worker::commit(cudaStream_t stream) {
|
|||||||
}
|
}
|
||||||
signal_event_.record(stream);
|
signal_event_.record(stream);
|
||||||
signal_event_.wait(signal_stream_);
|
signal_event_.wait(signal_stream_);
|
||||||
cudaLaunchHostFunc(signal_stream_, signal_worker, this);
|
cudaLaunchHostFunc(signal_stream_, signal, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::thread_fn() {
|
void Worker::thread_fn() {
|
||||||
|
|||||||
@@ -13,8 +13,6 @@
|
|||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
void signal_worker(void* data);
|
|
||||||
|
|
||||||
// Run tasks in worker thread, synchronized with cuda stream.
|
// Run tasks in worker thread, synchronized with cuda stream.
|
||||||
class Worker {
|
class Worker {
|
||||||
public:
|
public:
|
||||||
@@ -32,9 +30,8 @@ class Worker {
|
|||||||
void commit(cudaStream_t stream);
|
void commit(cudaStream_t stream);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend void signal_worker(void*);
|
static void signal(void*);
|
||||||
|
|
||||||
void signal_();
|
|
||||||
void thread_fn();
|
void thread_fn();
|
||||||
std::mutex mtx_;
|
std::mutex mtx_;
|
||||||
std::condition_variable cond_;
|
std::condition_variable cond_;
|
||||||
|
|||||||
Reference in New Issue
Block a user