mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Don't use shared event in worker
This commit is contained in:
@@ -30,6 +30,9 @@ SmallSizePool::SmallSizePool() {
|
|||||||
reinterpret_cast<char*>(buffer_) + small_pool_size);
|
reinterpret_cast<char*>(buffer_) + small_pool_size);
|
||||||
next_free_ = reinterpret_cast<Block*>(buffer_);
|
next_free_ = reinterpret_cast<Block*>(buffer_);
|
||||||
|
|
||||||
|
CHECK_CUDA_ERROR(
|
||||||
|
cudaMemAdvise(buffer_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
|
||||||
|
|
||||||
auto num_blocks = small_pool_size / small_block_size;
|
auto num_blocks = small_pool_size / small_block_size;
|
||||||
auto curr = next_free_;
|
auto curr = next_free_;
|
||||||
for (size_t i = 0; i < num_blocks - 1; ++i) {
|
for (size_t i = 0; i < num_blocks - 1; ++i) {
|
||||||
|
|||||||
@@ -306,7 +306,6 @@ void CommandEncoder::commit() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Put completion handlers in a batch.
|
// Put completion handlers in a batch.
|
||||||
worker_.end_batch();
|
|
||||||
worker_.commit(stream_);
|
worker_.commit(stream_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ Worker::Worker()
|
|||||||
|
|
||||||
Worker::~Worker() {
|
Worker::~Worker() {
|
||||||
{
|
{
|
||||||
std::lock_guard lock(worker_mutex_);
|
std::lock_guard lock(mtx_);
|
||||||
stop_ = true;
|
stop_ = true;
|
||||||
}
|
}
|
||||||
worker_event_.signal(batch_ + 1);
|
cond_.notify_one();
|
||||||
worker_.join();
|
worker_.join();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,35 +22,45 @@ void Worker::add_task(std::function<void()> task) {
|
|||||||
pending_tasks_.push_back(std::move(task));
|
pending_tasks_.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::end_batch() {
|
void signal_worker(void* data) {
|
||||||
batch_++;
|
auto w = static_cast<Worker*>(data);
|
||||||
{
|
w->signal_();
|
||||||
std::lock_guard lock(worker_mutex_);
|
|
||||||
worker_tasks_[batch_] = std::move(pending_tasks_);
|
|
||||||
}
|
}
|
||||||
uncommited_batches_++;
|
|
||||||
|
void Worker::signal_() {
|
||||||
|
{
|
||||||
|
std::lock_guard lock(mtx_);
|
||||||
|
signaled_batch_++;
|
||||||
|
}
|
||||||
|
cond_.notify_one();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::commit(cudaStream_t stream) {
|
void Worker::commit(cudaStream_t stream) {
|
||||||
if (uncommited_batches_ == 0) {
|
// Move pending tasks into tasks
|
||||||
|
if (pending_tasks_.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
uncommited_batches_ = 0;
|
{
|
||||||
// Signal the |worker_event_| in |signal_stream_| after the kernels in
|
std::lock_guard lock(mtx_);
|
||||||
// |stream_| finish running.
|
// Move pending tasks into ready tasks
|
||||||
|
worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
|
||||||
|
}
|
||||||
signal_event_.record(stream);
|
signal_event_.record(stream);
|
||||||
signal_event_.wait(signal_stream_);
|
signal_event_.wait(signal_stream_);
|
||||||
worker_event_.signal(signal_stream_, batch_);
|
cudaLaunchHostFunc(signal_stream_, signal_worker, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Worker::thread_fn() {
|
void Worker::thread_fn() {
|
||||||
while (!stop_) {
|
while (!stop_) {
|
||||||
uint64_t batch = worker_event_.value();
|
uint64_t current_batch = 0;
|
||||||
Tasks tasks;
|
Tasks tasks;
|
||||||
{
|
{
|
||||||
std::lock_guard lock(worker_mutex_);
|
std::unique_lock<std::mutex> lk(mtx_);
|
||||||
// Move tasks in signaled batches.
|
cond_.wait(lk, [this, &current_batch] {
|
||||||
auto end = worker_tasks_.upper_bound(batch);
|
return this->signaled_batch_ > current_batch || this->stop_;
|
||||||
|
});
|
||||||
|
current_batch = signaled_batch_;
|
||||||
|
auto end = worker_tasks_.upper_bound(current_batch);
|
||||||
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
for (auto it = worker_tasks_.begin(); it != end; ++it) {
|
||||||
if (tasks.empty()) {
|
if (tasks.empty()) {
|
||||||
tasks = std::move(it->second);
|
tasks = std::move(it->second);
|
||||||
@@ -66,7 +76,6 @@ void Worker::thread_fn() {
|
|||||||
auto task = std::move(tasks[i]);
|
auto task = std::move(tasks[i]);
|
||||||
task();
|
task();
|
||||||
}
|
}
|
||||||
worker_event_.wait(batch + 1);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "mlx/backend/cuda/event.h"
|
#include "mlx/backend/cuda/event.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
|
|
||||||
|
#include <condition_variable>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
@@ -12,6 +13,8 @@
|
|||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
|
|
||||||
|
void signal_worker(void* data);
|
||||||
|
|
||||||
// Run tasks in worker thread, synchronized with cuda stream.
|
// Run tasks in worker thread, synchronized with cuda stream.
|
||||||
class Worker {
|
class Worker {
|
||||||
public:
|
public:
|
||||||
@@ -24,32 +27,25 @@ class Worker {
|
|||||||
// Add a pending |task| that will run when consumed or commited.
|
// Add a pending |task| that will run when consumed or commited.
|
||||||
void add_task(std::function<void()> task);
|
void add_task(std::function<void()> task);
|
||||||
|
|
||||||
// Put pending tasks in a batch.
|
|
||||||
void end_batch();
|
|
||||||
|
|
||||||
// Inform worker thread to run current batches after kernels in |stream|
|
// Inform worker thread to run current batches after kernels in |stream|
|
||||||
// finish running.
|
// finish running.
|
||||||
void commit(cudaStream_t stream);
|
void commit(cudaStream_t stream);
|
||||||
|
|
||||||
// Return how many batches have been added but not committed yet.
|
|
||||||
size_t uncommited_batches() const {
|
|
||||||
return uncommited_batches_;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void thread_fn();
|
friend void signal_worker(void*);
|
||||||
|
|
||||||
uint64_t batch_{0};
|
void signal_();
|
||||||
size_t uncommited_batches_{0};
|
void thread_fn();
|
||||||
|
std::mutex mtx_;
|
||||||
|
std::condition_variable cond_;
|
||||||
|
|
||||||
|
uint64_t committed_batch_{0};
|
||||||
|
uint64_t signaled_batch_{0};
|
||||||
|
|
||||||
// Cuda stream and event for signaling kernel completion.
|
// Cuda stream and event for signaling kernel completion.
|
||||||
CudaStream signal_stream_;
|
CudaStream signal_stream_;
|
||||||
CudaEvent signal_event_;
|
CudaEvent signal_event_;
|
||||||
|
|
||||||
// Worker thread.
|
|
||||||
SharedEvent worker_event_;
|
|
||||||
std::thread worker_;
|
|
||||||
std::mutex worker_mutex_;
|
|
||||||
bool stop_{false};
|
bool stop_{false};
|
||||||
|
|
||||||
// Tasks are put in |pending_tasks_| first, and then moved to
|
// Tasks are put in |pending_tasks_| first, and then moved to
|
||||||
@@ -57,6 +53,7 @@ class Worker {
|
|||||||
using Tasks = std::vector<std::function<void()>>;
|
using Tasks = std::vector<std::function<void()>>;
|
||||||
Tasks pending_tasks_;
|
Tasks pending_tasks_;
|
||||||
std::map<uint64_t, Tasks> worker_tasks_;
|
std::map<uint64_t, Tasks> worker_tasks_;
|
||||||
|
std::thread worker_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
|||||||
Reference in New Issue
Block a user