[CUDA] Recycle CUDA events (#2604)

* Make CudaEvent a CudaHandle

* Add caching for CudaEvent

* Make sure CUDA events are destroyed last

* Fix headers

* SharedEvent => AtomicEvent

* RawCudaEvent => CudaEventHandle, CudaEventWrapper => CopyableCudaEvent

* Remove unneeded asserts
Author: Cheng
Committed: 2025-09-23 10:42:03 +09:00 (by GitHub)
Parent: 711a645807
Commit: ae438d05fa
6 changed files with 159 additions and 110 deletions
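
The core change is an object-pool pattern for cudaEvent_t: creation and destruction are replaced by checkout and return against a per-flags free list. A minimal standalone sketch of the idea (illustrative names, not the MLX code itself; error checking omitted):

#include <cuda_runtime.h>

#include <map>
#include <vector>

class PooledEvent {
 public:
  explicit PooledEvent(int flags) : flags_(flags) {
    auto& cache = cache_for(flags);
    if (cache.empty()) {
      cudaEventCreateWithFlags(&event_, flags);  // slow path: real creation
    } else {
      event_ = cache.back();  // fast path: reuse a released event
      cache.pop_back();
    }
  }
  ~PooledEvent() {
    cache_for(flags_).push_back(event_);  // return to the pool, never destroy
  }
  PooledEvent(const PooledEvent&) = delete;
  PooledEvent& operator=(const PooledEvent&) = delete;

  operator cudaEvent_t() const { return event_; }

 private:
  static std::vector<cudaEvent_t>& cache_for(int flags) {
    // One free list per flag combination; a function-local static so it is
    // created on first use and destroyed after its last user.
    static std::map<int, std::vector<cudaEvent_t>> cache;
    return cache[flags];
  }

  cudaEvent_t event_;
  int flags_;
};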

View File

@@ -15,8 +15,9 @@ bool is_available() {
 }
 
 void new_stream(Stream s) {
-  // Force initalization of cuda, so cuda runtime get destroyed at last.
-  cudaFree(nullptr);
+  // Force initialization of CUDA by creating an event, so the CUDA runtime and
+  // our CUDA event pool get destroyed last.
+  cu::CudaEvent(cudaEventDefault);
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
 }
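
The replacement relies on a C++ guarantee: function-local statics are destroyed in reverse order of construction, so initializing the event pool (and the CUDA runtime) before the other statics means they are torn down after everything that uses them. A minimal sketch of that ordering, with hypothetical names:

#include <cstdio>

struct Tracer {
  const char* name;
  ~Tracer() { std::printf("destroying %s\n", name); }
};

void init_event_pool() { static Tracer t{"event pool (constructed first)"}; }
void init_encoder() { static Tracer t{"command encoder (constructed later)"}; }

int main() {
  init_event_pool();  // plays the role of cu::CudaEvent(cudaEventDefault)
  init_encoder();     // plays the role of cu::get_command_encoder(s)
  // Printed at exit, in reverse order of construction:
  //   destroying command encoder (constructed later)
  //   destroying event pool (constructed first)
}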

View File

@@ -3,10 +3,12 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/event.h"
-#include "mlx/backend/cuda/utils.h"
 #include "mlx/event.h"
 #include "mlx/scheduler.h"
 
+#include <map>
+#include <vector>
+
 #include <nvtx3/nvtx3.hpp>
 
 namespace mlx::core {
@@ -17,104 +19,141 @@ namespace cu {
 ///////////////////////////////////////////////////////////////////////////////
 // CudaEvent implementations
 ///////////////////////////////////////////////////////////////////////////////
 
-// Cuda event managed with RAII.
-class CudaEventHandle {
- public:
-  CudaEventHandle() {
-    CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
-        &event_, cudaEventDisableTiming | cudaEventBlockingSync));
-  }
-
-  ~CudaEventHandle() {
-    CHECK_CUDA_ERROR(cudaEventDestroy(event_));
-  }
-
-  CudaEventHandle(const CudaEventHandle&) = delete;
-  CudaEventHandle& operator=(const CudaEventHandle&) = delete;
-
-  operator cudaEvent_t() const {
-    return event_;
-  }
-
- private:
-  cudaEvent_t event_;
-};
-
-CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
+namespace {
+
+// Manage cached cudaEvent_t objects.
+struct CudaEventPool {
+  static CudaEventHandle create(int flags) {
+    auto& cache = cache_for(flags);
+    if (cache.empty()) {
+      return CudaEventHandle(flags);
+    } else {
+      CudaEventHandle ret = std::move(cache.back());
+      cache.pop_back();
+      return ret;
+    }
+  }
+
+  static void release(CudaEventHandle event) {
+    cache_for(event.flags).push_back(std::move(event));
+  }
+
+  static std::vector<CudaEventHandle>& cache_for(int flags) {
+    static std::map<int, std::vector<CudaEventHandle>> cache;
+    return cache[flags];
+  }
+};
+
+} // namespace
+
+CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
+  CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
+  assert(handle_ != nullptr);
+}
+
+CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
+
+CudaEvent::~CudaEvent() {
+  CudaEventPool::release(std::move(event_));
+}
 
 void CudaEvent::wait() {
   nvtx3::scoped_range r("cu::CudaEvent::wait");
-  if (!recorded_) {
-    throw std::runtime_error("Should not wait on a CudaEvent before record.");
-  }
-  cudaEventSynchronize(*event_);
+  cudaEventSynchronize(event_);
 }
 
 void CudaEvent::wait(cudaStream_t stream) {
-  if (!recorded_) {
-    throw std::runtime_error("Should not wait on a CudaEvent before record.");
-  }
-  cudaStreamWaitEvent(stream, *event_);
-}
-
-void CudaEvent::wait(Stream s) {
-  if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [*this]() mutable { wait(); });
-  } else {
-    auto& enc = cu::get_command_encoder(s);
-    enc.commit();
-    wait(enc.stream());
-  }
+  cudaStreamWaitEvent(stream, event_);
 }
 
 void CudaEvent::record(cudaStream_t stream) {
-  cudaEventRecord(*event_, stream);
-  recorded_ = true;
-}
-
-void CudaEvent::record(Stream s) {
-  if (s.device == mlx::core::Device::cpu) {
-    throw std::runtime_error("CudaEvent can not wait on cpu stream.");
-  } else {
-    auto& enc = cu::get_command_encoder(s);
-    enc.commit();
-    record(enc.stream());
-  }
+  cudaEventRecord(event_, stream);
 }
 
 bool CudaEvent::completed() const {
-  return cudaEventQuery(*event_) == cudaSuccess;
+  return cudaEventQuery(event_) == cudaSuccess;
 }
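
A usage sketch of the recycled event (hypothetical helper, assuming the cu::CudaEvent API defined above): the constructor checks an event out of the pool and the destructor returns it, so hot synchronization paths no longer pay for cudaEventCreate/cudaEventDestroy.

#include "mlx/backend/cuda/event.h"

// Hypothetical helper: synchronize a consumer stream on work previously
// enqueued on a producer stream.
void sync_streams(cudaStream_t producer, cudaStream_t consumer) {
  mlx::core::cu::CudaEvent event(cudaEventDisableTiming);
  event.record(producer);  // mark the point the consumer must reach
  event.wait(consumer);    // device-side wait; the host is not blocked
}  // the destructor returns the cudaEvent_t to CudaEventPool here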
+// Wraps CudaEvent with a few features:
+// 1. The class can be copied.
+// 2. Make wait/record work with CPU streams.
+// 3. Add checks for waiting on un-recorded event.
+class CopyableCudaEvent {
+ public:
+  CopyableCudaEvent()
+      : event_(std::make_shared<CudaEvent>(
+            cudaEventDisableTiming | cudaEventBlockingSync)) {}
+
+  void wait() {
+    event_->wait();
+  }
+
+  void wait(Stream s) {
+    if (s.device == mlx::core::Device::cpu) {
+      scheduler::enqueue(s, [*this]() mutable {
+        check_recorded();
+        event_->wait();
+      });
+    } else {
+      check_recorded();
+      auto& encoder = cu::get_command_encoder(s);
+      encoder.commit();
+      event_->wait(encoder.stream());
+    }
+  }
+
+  void record(Stream s) {
+    if (s.device == mlx::core::Device::cpu) {
+      throw std::runtime_error("CudaEvent can not wait on CPU stream.");
+    } else {
+      auto& encoder = cu::get_command_encoder(s);
+      encoder.commit();
+      event_->record(encoder.stream());
+      recorded_ = true;
+    }
+  }
+
+  bool is_signaled() const {
+    return recorded_ && event_->completed();
+  }
+
+ private:
+  void check_recorded() const {
+    if (!recorded_) {
+      throw std::runtime_error(
+          "Should not wait on a CudaEvent before recording.");
+    }
+  }
+
+  std::shared_ptr<CudaEvent> event_;
+  bool recorded_{false};
+};
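
The shared_ptr indirection is what lets CopyableCudaEvent be captured by value in the scheduler closure above even though CudaEvent itself is move-only; the idiom in isolation:

#include <memory>

struct MoveOnly {
  MoveOnly() = default;
  MoveOnly(MoveOnly&&) = default;
  MoveOnly(const MoveOnly&) = delete;
};

struct CopyableView {
  // Copying CopyableView only bumps a refcount; every copy shares the same
  // underlying MoveOnly, exactly as CopyableCudaEvent shares one CudaEvent.
  std::shared_ptr<MoveOnly> inner = std::make_shared<MoveOnly>();
};

auto make_task(CopyableView view) {
  return [view]() { /* use *view.inner on another thread */ };
}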
 ///////////////////////////////////////////////////////////////////////////////
-// SharedEvent implementations
+// AtomicEvent implementations
 ///////////////////////////////////////////////////////////////////////////////
 
-__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
+__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
   uint64_t current;
   while ((current = ac->load()) < value) {
     ac->wait(current);
   }
 }
 
-__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
+__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
   ac->store(value);
   ac->notify_all();
 }
 
-__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
+__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
   event_wait(ac, value);
 }
 
-__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
+__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
   event_signal(ac, value);
 }
 
-SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
-  return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
-}
-
-SharedEvent::SharedEvent() {
+AtomicEvent::AtomicEvent() {
   buf_ = std::shared_ptr<Buffer>(
       new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
         allocator().free(*ptr);
@@ -123,17 +162,17 @@ SharedEvent::SharedEvent() {
   *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
 }
 
-void SharedEvent::wait(uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::wait");
-  event_wait(to_atomic(buf_), value);
+void AtomicEvent::wait(uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::wait");
+  event_wait(atomic(), value);
 }
 
-void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
-  event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
+void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
+  event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
 }
 
-void SharedEvent::wait(Stream s, uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
+void AtomicEvent::wait(Stream s, uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
   if (s.device == mlx::core::Device::cpu) {
     scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
   } else {
@@ -144,17 +183,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
   }
 }
 
-void SharedEvent::signal(uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::signal");
-  event_signal(to_atomic(buf_), value);
+void AtomicEvent::signal(uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::signal");
+  event_signal(atomic(), value);
 }
 
-void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
-  event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
+void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
+  event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
 }
 
-void SharedEvent::signal(Stream s, uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
+void AtomicEvent::signal(Stream s, uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
   if (s.device == mlx::core::Device::cpu) {
     // Signal through a GPU stream so the atomic is updated in GPU - updating
     // the atomic in CPU sometimes does not get GPU notified.
@@ -168,14 +207,14 @@ void SharedEvent::signal(Stream s, uint64_t value) {
   }
 }
 
-bool SharedEvent::is_signaled(uint64_t value) const {
-  nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
-  return to_atomic(buf_)->load() >= value;
+bool AtomicEvent::is_signaled(uint64_t value) const {
+  nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
+  return atomic()->load() >= value;
 }
 
-uint64_t SharedEvent::value() const {
-  nvtx3::scoped_range r("cu::SharedEvent::value");
-  return to_atomic(buf_)->load();
+uint64_t AtomicEvent::value() const {
+  nvtx3::scoped_range r("cu::AtomicEvent::value");
+  return atomic()->load();
 }
 
 } // namespace cu
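
For readers unfamiliar with the AtomicEvent mechanism, here is a self-contained sketch of the same host/device handshake, using cudaMallocManaged in place of MLX's allocator (an assumption made for the sketch; it also requires a GPU with concurrent managed access):

#include <cuda/atomic>
#include <cuda_runtime.h>

#include <cstdint>
#include <new>

using Atomic = cuda::atomic<uint64_t>;  // thread_scope_system by default

__global__ void signal_kernel(Atomic* ac, uint64_t value) {
  ac->store(value);
  ac->notify_all();  // wake threads blocked in ac->wait()
}

int main() {
  Atomic* ac;
  cudaMallocManaged(&ac, sizeof(Atomic));
  new (ac) Atomic(0);  // construct the atomic in managed memory

  signal_kernel<<<1, 1>>>(ac, 1);

  // Host-side wait: the same load/wait loop as event_wait() above.
  uint64_t current;
  while ((current = ac->load()) < 1) {
    ac->wait(current);
  }
  cudaFree(ac);
}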
@@ -188,14 +227,14 @@ namespace {
 
 struct EventImpl {
   // CudaEvent is preferred when possible because it is fast, however we have
-  // to fallback to SharedEvent in following cases:
+  // to fallback to AtomicEvent in following cases:
   // 1. the event is used to wait/signal a cpu stream;
   // 2. signal value other than 1 has been specified.
-  std::unique_ptr<cu::CudaEvent> cuda;
-  std::unique_ptr<cu::SharedEvent> shared;
+  std::unique_ptr<cu::CopyableCudaEvent> cuda;
+  std::unique_ptr<cu::AtomicEvent> atomic;
 
   bool is_created() const {
-    return cuda || shared;
+    return cuda || atomic;
   }
 
   void ensure_created(Stream s, uint64_t signal_value) {
@@ -203,10 +242,10 @@ struct EventImpl {
       return;
     }
     if (s.device == mlx::core::Device::cpu || signal_value > 1) {
-      nvtx3::mark("Using slow SharedEvent");
-      shared = std::make_unique<cu::SharedEvent>();
+      nvtx3::mark("Using slow AtomicEvent");
+      atomic = std::make_unique<cu::AtomicEvent>();
     } else {
-      cuda = std::make_unique<cu::CudaEvent>();
+      cuda = std::make_unique<cu::CopyableCudaEvent>();
     }
   }
 };
@@ -225,7 +264,7 @@ void Event::wait() {
     assert(value() == 1);
     event->cuda->wait();
   } else {
-    event->shared->wait(value());
+    event->atomic->wait(value());
   }
 }
 
@@ -236,7 +275,7 @@ void Event::wait(Stream s) {
     assert(value() == 1);
     event->cuda->wait(s);
   } else {
-    event->shared->wait(s, value());
+    event->atomic->wait(s, value());
   }
 }
 
@@ -247,7 +286,7 @@ void Event::signal(Stream s) {
     assert(value() == 1);
     event->cuda->record(s);
   } else {
-    event->shared->signal(s, value());
+    event->atomic->signal(s, value());
   }
 }
 
@@ -258,9 +297,9 @@ bool Event::is_signaled() const {
   }
   if (event->cuda) {
     assert(value() == 1);
-    return event->cuda->recorded() && event->cuda->completed();
+    return event->cuda->is_signaled();
   } else {
-    return event->shared->is_signaled(value());
+    return event->atomic->is_signaled(value());
   }
 }

View File

@@ -3,49 +3,54 @@
 #pragma once
 
 #include "mlx/allocator.h"
+#include "mlx/backend/cuda/utils.h"
 #include "mlx/stream.h"
 
-#include <memory>
-
 #include <cuda_runtime.h>
 #include <cuda/atomic>
+#include <memory>
 
 namespace mlx::core::cu {
 
-class CudaEventHandle;
+// RAII-managed move-only wrapper of cudaEvent_t.
+struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
+  CudaEventHandle(int flags);
+  int flags;
+};
 
 // Wrapper of native cuda event. It can synchronize between GPU streams, or wait
 // on GPU stream in CPU stream, but can not wait on CPU stream.
 class CudaEvent {
  public:
-  CudaEvent();
+  explicit CudaEvent(int flags);
+  ~CudaEvent();
+
+  CudaEvent(CudaEvent&&) = default;
+  CudaEvent& operator=(CudaEvent&&) = default;
+  CudaEvent(const CudaEvent&) = delete;
+  CudaEvent& operator=(const CudaEvent&) = delete;
 
   void wait();
   void wait(cudaStream_t stream);
-  void wait(Stream s);
   void record(cudaStream_t stream);
-  void record(Stream s);
 
   // Return whether the recorded kernels have completed. Note that this method
   // returns true if record() has not been called.
   bool completed() const;
 
-  bool recorded() const {
-    return recorded_;
-  }
-
  private:
-  bool recorded_{false};
-  std::shared_ptr<CudaEventHandle> event_;
+  CudaEventHandle event_;
 };
 
 // Event that can synchronize between CPU and GPU. It is much slower than
 // CudaEvent so the latter should always be preferred when possible.
-class SharedEvent {
+class AtomicEvent {
  public:
   using Atomic = cuda::atomic<uint64_t>;
 
-  SharedEvent();
+  AtomicEvent();
 
   void wait(uint64_t value);
   void wait(cudaStream_t stream, uint64_t value);
@@ -57,7 +62,11 @@ class SharedEvent {
   uint64_t value() const;
 
  private:
-  std::shared_ptr<mlx::core::allocator::Buffer> buf_;
+  Atomic* atomic() const {
+    return static_cast<AtomicEvent::Atomic*>(buf_->raw_ptr());
+  }
+
+  std::shared_ptr<allocator::Buffer> buf_;
 };
 
 } // namespace mlx::core::cu
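
The CudaHandle base comes from mlx/backend/cuda/utils.h, which this diff does not show. A plausible minimal shape for such a move-only RAII wrapper, stated as an assumption rather than the actual MLX definition:

#include <cuda_runtime.h>

#include <utility>

// Assumed shape of CudaHandle: a move-only owner parameterized on the handle
// type and its destroy function (e.g. CudaHandle<cudaEvent_t, cudaEventDestroy>).
template <typename T, cudaError_t (*Destroy)(T)>
class CudaHandle {
 public:
  CudaHandle(T handle = nullptr) : handle_(handle) {}
  CudaHandle(CudaHandle&& other) noexcept
      : handle_(std::exchange(other.handle_, nullptr)) {}
  CudaHandle& operator=(CudaHandle&& other) noexcept {
    if (this != &other) {
      reset();
      handle_ = std::exchange(other.handle_, nullptr);
    }
    return *this;
  }
  CudaHandle(const CudaHandle&) = delete;
  CudaHandle& operator=(const CudaHandle&) = delete;
  ~CudaHandle() { reset(); }

  operator T() const { return handle_; }

 protected:
  void reset() {
    if (handle_ != nullptr) {
      Destroy(handle_);
      handle_ = nullptr;
    }
  }

  T handle_;
};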

View File

@@ -7,7 +7,7 @@ namespace mlx::core {
 
 struct FenceImpl {
   uint32_t count;
-  cu::SharedEvent event;
+  cu::AtomicEvent event;
 };
 
 Fence::Fence(Stream s) {

View File

@@ -7,6 +7,7 @@ namespace mlx::core::cu {
 
 Worker::Worker()
     : signal_stream_(device(mlx::core::Device::gpu)),
+      signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
       worker_(&Worker::thread_fn, this) {}
 
 Worker::~Worker() {

View File

@@ -3,7 +3,6 @@
 #pragma once
 
 #include "mlx/backend/cuda/event.h"
-#include "mlx/backend/cuda/utils.h"
 
 #include <condition_variable>
 #include <functional>