diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp
index 0c2690dcf..b91fafa22 100644
--- a/mlx/backend/cuda/eval.cpp
+++ b/mlx/backend/cuda/eval.cpp
@@ -15,8 +15,9 @@ bool is_available() {
 }
 
 void new_stream(Stream s) {
-  // Force initalization of cuda, so cuda runtime get destroyed at last.
-  cudaFree(nullptr);
+  // Force initialization of CUDA by creating an event, so the CUDA runtime
+  // and our CUDA event pool get destroyed last.
+  cu::CudaEvent(cudaEventDefault);
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
 }
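Note on the new_stream() change above: C++ destroys function-local statics in reverse order of construction, so constructing one pooled event here, before any stream state exists, both initializes the CUDA runtime and guarantees that the pool's static caches outlive every command encoder created afterwards. A minimal sketch of that ordering rule (Tracer, pool, and stream_state are illustrative names, not MLX code):

```cpp
#include <cstdio>

struct Tracer {
  const char* name;
  ~Tracer() { std::printf("destroying %s\n", name); }
};

Tracer& pool() {
  static Tracer t{"event pool"};  // constructed on first use
  return t;
}

Tracer& stream_state() {
  static Tracer t{"stream state"};
  return t;
}

int main() {
  pool();          // construct the pool first...
  stream_state();  // ...then the per-stream state
  // At exit, function-local statics die in reverse construction order:
  // "stream state" is destroyed first, "event pool" last.
}
```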
diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu
index c9c6b3659..2d9c96604 100644
--- a/mlx/backend/cuda/event.cu
+++ b/mlx/backend/cuda/event.cu
@@ -3,10 +3,12 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/event.h"
-#include "mlx/backend/cuda/utils.h"
 #include "mlx/event.h"
 #include "mlx/scheduler.h"
 
+#include <map>
+#include <vector>
+
 #include <nvtx3/nvtx3.hpp>
 
 namespace mlx::core {
@@ -17,104 +19,141 @@ namespace cu {
 // CudaEvent implementations
 ///////////////////////////////////////////////////////////////////////////////
 
-// Cuda event managed with RAII.
-class CudaEventHandle {
- public:
-  CudaEventHandle() {
-    CHECK_CUDA_ERROR(cudaEventCreateWithFlags(
-        &event_, cudaEventDisableTiming | cudaEventBlockingSync));
+namespace {
+
+// Manage cached cudaEvent_t objects.
+struct CudaEventPool {
+  static CudaEventHandle create(int flags) {
+    auto& cache = cache_for(flags);
+    if (cache.empty()) {
+      return CudaEventHandle(flags);
+    } else {
+      CudaEventHandle ret = std::move(cache.back());
+      cache.pop_back();
+      return ret;
+    }
   }
 
-  ~CudaEventHandle() {
-    CHECK_CUDA_ERROR(cudaEventDestroy(event_));
+  static void release(CudaEventHandle event) {
+    cache_for(event.flags).push_back(std::move(event));
   }
 
-  CudaEventHandle(const CudaEventHandle&) = delete;
-  CudaEventHandle& operator=(const CudaEventHandle&) = delete;
-
-  operator cudaEvent_t() const {
-    return event_;
+  static std::vector<CudaEventHandle>& cache_for(int flags) {
+    static std::map<int, std::vector<CudaEventHandle>> cache;
+    return cache[flags];
  }
-
- private:
-  cudaEvent_t event_;
 };
 
-CudaEvent::CudaEvent() : event_(std::make_shared<CudaEventHandle>()) {}
+} // namespace
+
+CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
+  CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
+  assert(handle_ != nullptr);
+}
+
+CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
+
+CudaEvent::~CudaEvent() {
+  CudaEventPool::release(std::move(event_));
+}
 
 void CudaEvent::wait() {
   nvtx3::scoped_range r("cu::CudaEvent::wait");
-  if (!recorded_) {
-    throw std::runtime_error("Should not wait on a CudaEvent before record.");
-  }
-  cudaEventSynchronize(*event_);
+  cudaEventSynchronize(event_);
 }
 
 void CudaEvent::wait(cudaStream_t stream) {
-  if (!recorded_) {
-    throw std::runtime_error("Should not wait on a CudaEvent before record.");
-  }
-  cudaStreamWaitEvent(stream, *event_);
-}
-
-void CudaEvent::wait(Stream s) {
-  if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [*this]() mutable { wait(); });
-  } else {
-    auto& enc = cu::get_command_encoder(s);
-    enc.commit();
-    wait(enc.stream());
-  }
+  cudaStreamWaitEvent(stream, event_);
 }
 
 void CudaEvent::record(cudaStream_t stream) {
-  cudaEventRecord(*event_, stream);
-  recorded_ = true;
-}
-
-void CudaEvent::record(Stream s) {
-  if (s.device == mlx::core::Device::cpu) {
-    throw std::runtime_error("CudaEvent can not wait on cpu stream.");
-  } else {
-    auto& enc = cu::get_command_encoder(s);
-    enc.commit();
-    record(enc.stream());
-  }
+  cudaEventRecord(event_, stream);
 }
 
 bool CudaEvent::completed() const {
-  return cudaEventQuery(*event_) == cudaSuccess;
+  return cudaEventQuery(event_) == cudaSuccess;
 }
 
+// Wraps CudaEvent with a few features:
+// 1. The class can be copied.
+// 2. Makes wait/record work with CPU streams.
+// 3. Adds checks for waiting on an un-recorded event.
+class CopyableCudaEvent {
+ public:
+  CopyableCudaEvent()
+      : event_(std::make_shared<CudaEvent>(
+            cudaEventDisableTiming | cudaEventBlockingSync)) {}
+
+  void wait() {
+    event_->wait();
+  }
+
+  void wait(Stream s) {
+    if (s.device == mlx::core::Device::cpu) {
+      scheduler::enqueue(s, [*this]() mutable {
+        check_recorded();
+        event_->wait();
+      });
+    } else {
+      check_recorded();
+      auto& encoder = cu::get_command_encoder(s);
+      encoder.commit();
+      event_->wait(encoder.stream());
+    }
+  }
+
+  void record(Stream s) {
+    if (s.device == mlx::core::Device::cpu) {
+      throw std::runtime_error("CudaEvent can not be recorded on CPU stream.");
+    } else {
+      auto& encoder = cu::get_command_encoder(s);
+      encoder.commit();
+      event_->record(encoder.stream());
+      recorded_ = true;
+    }
+  }
+
+  bool is_signaled() const {
+    return recorded_ && event_->completed();
+  }
+
+ private:
+  void check_recorded() const {
+    if (!recorded_) {
+      throw std::runtime_error(
+          "Should not wait on a CudaEvent before recording.");
+    }
+  }
+
+  std::shared_ptr<CudaEvent> event_;
+  bool recorded_{false};
+};
+
 ///////////////////////////////////////////////////////////////////////////////
-// SharedEvent implementations
+// AtomicEvent implementations
 ///////////////////////////////////////////////////////////////////////////////
 
-__host__ __device__ void event_wait(SharedEvent::Atomic* ac, uint64_t value) {
+__host__ __device__ void event_wait(AtomicEvent::Atomic* ac, uint64_t value) {
   uint64_t current;
   while ((current = ac->load()) < value) {
     ac->wait(current);
   }
 }
 
-__host__ __device__ void event_signal(SharedEvent::Atomic* ac, uint64_t value) {
+__host__ __device__ void event_signal(AtomicEvent::Atomic* ac, uint64_t value) {
   ac->store(value);
   ac->notify_all();
 }
 
-__global__ void event_wait_kernel(SharedEvent::Atomic* ac, uint64_t value) {
+__global__ void event_wait_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
   event_wait(ac, value);
 }
 
-__global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
+__global__ void event_signal_kernel(AtomicEvent::Atomic* ac, uint64_t value) {
   event_signal(ac, value);
 }
 
-SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
-  return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
-}
-
-SharedEvent::SharedEvent() {
+AtomicEvent::AtomicEvent() {
   buf_ = std::shared_ptr<Buffer>(
       new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
         allocator().free(*ptr);
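For context, this is what the pooled API looks like at a call site. The snippet below is a hypothetical usage example, not code from this patch; it assumes a valid stream and shows that ~CudaEvent() recycles the handle through CudaEventPool::release() instead of calling cudaEventDestroy():

```cpp
#include "mlx/backend/cuda/event.h"

using namespace mlx::core;

// Hypothetical call site illustrating the pooled event lifecycle.
void example(cudaStream_t stream) {
  {
    cu::CudaEvent ev(cudaEventDisableTiming | cudaEventBlockingSync);
    ev.record(stream);  // capture the work queued on the stream so far
    ev.wait();          // block the host until that work completes
  }  // ~CudaEvent() returns the cudaEvent_t to the pool

  // A later event with the same flags reuses the cached handle, skipping
  // cudaEventCreateWithFlags() on the fast path.
  cu::CudaEvent ev2(cudaEventDisableTiming | cudaEventBlockingSync);
  ev2.record(stream);
  ev2.wait();
}
```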
@@ -123,17 +162,17 @@ SharedEvent::SharedEvent() {
   *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
 }
 
-void SharedEvent::wait(uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::wait");
-  event_wait(to_atomic(buf_), value);
+void AtomicEvent::wait(uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::wait");
+  event_wait(atomic(), value);
 }
 
-void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
-  event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
+void AtomicEvent::wait(cudaStream_t stream, uint64_t value) {
+  event_wait_kernel<<<1, 1, 0, stream>>>(atomic(), value);
 }
 
-void SharedEvent::wait(Stream s, uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::wait(s)");
+void AtomicEvent::wait(Stream s, uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
   if (s.device == mlx::core::Device::cpu) {
     scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
   } else {
@@ -144,17 +183,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
   }
 }
 
-void SharedEvent::signal(uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::signal");
-  event_signal(to_atomic(buf_), value);
+void AtomicEvent::signal(uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::signal");
+  event_signal(atomic(), value);
 }
 
-void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
-  event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
+void AtomicEvent::signal(cudaStream_t stream, uint64_t value) {
+  event_signal_kernel<<<1, 1, 0, stream>>>(atomic(), value);
 }
 
-void SharedEvent::signal(Stream s, uint64_t value) {
-  nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
+void AtomicEvent::signal(Stream s, uint64_t value) {
+  nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
   if (s.device == mlx::core::Device::cpu) {
     // Signal through a GPU stream so the atomic is updated in GPU - updating
     // the atomic in CPU sometimes does not get GPU notified.
@@ -168,14 +207,14 @@ void SharedEvent::signal(Stream s, uint64_t value) {
   }
 }
 
-bool SharedEvent::is_signaled(uint64_t value) const {
-  nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
-  return to_atomic(buf_)->load() >= value;
+bool AtomicEvent::is_signaled(uint64_t value) const {
+  nvtx3::scoped_range r("cu::AtomicEvent::is_signaled");
+  return atomic()->load() >= value;
 }
 
-uint64_t SharedEvent::value() const {
-  nvtx3::scoped_range r("cu::SharedEvent::value");
-  return to_atomic(buf_)->load();
+uint64_t AtomicEvent::value() const {
+  nvtx3::scoped_range r("cu::AtomicEvent::value");
+  return atomic()->load();
 }
 
 } // namespace cu
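The renames above are mechanical, but the mechanism deserves a note: AtomicEvent is a cuda::atomic<uint64_t> (system scope by default) placed in memory visible to both host and device, so the identical load/wait and store/notify_all protocol in event_wait()/event_signal() runs on either side. A standalone sketch of that protocol, using cudaMallocManaged instead of MLX's allocator (assumes a device with concurrent managed access):

```cpp
#include <cuda/atomic>

#include <cstdint>
#include <cstdio>
#include <new>

using Atomic = cuda::atomic<uint64_t>;  // thread_scope_system by default

__global__ void signal_kernel(Atomic* ac, uint64_t value) {
  ac->store(value);
  ac->notify_all();  // wake threads blocked in ac->wait()
}

int main() {
  Atomic* ac;
  cudaMallocManaged(&ac, sizeof(Atomic));
  new (ac) Atomic(0);  // construct the atomic in managed memory

  signal_kernel<<<1, 1>>>(ac, 1);

  // Same loop as event_wait(): re-check after each wakeup, since wait()
  // can return spuriously.
  uint64_t current;
  while ((current = ac->load()) < 1) {
    ac->wait(current);
  }
  std::printf("signaled with %llu\n", static_cast<unsigned long long>(current));

  cudaFree(ac);
  return 0;
}
```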
@@ -188,14 +227,14 @@ namespace {
 
 struct EventImpl {
   // CudaEvent is preferred when possible because it is fast, however we have
-  // to fallback to SharedEvent in following cases:
+  // to fall back to AtomicEvent in the following cases:
   // 1. the event is used to wait/signal a cpu stream;
   // 2. signal value other than 1 has been specified.
-  std::unique_ptr<cu::CudaEvent> cuda;
-  std::unique_ptr<cu::SharedEvent> shared;
+  std::unique_ptr<cu::CopyableCudaEvent> cuda;
+  std::unique_ptr<cu::AtomicEvent> atomic;
 
   bool is_created() const {
-    return cuda || shared;
+    return cuda || atomic;
   }
 
   void ensure_created(Stream s, uint64_t signal_value) {
@@ -203,10 +242,10 @@ struct EventImpl {
       return;
     }
     if (s.device == mlx::core::Device::cpu || signal_value > 1) {
-      nvtx3::mark("Using slow SharedEvent");
-      shared = std::make_unique<cu::SharedEvent>();
+      nvtx3::mark("Using slow AtomicEvent");
+      atomic = std::make_unique<cu::AtomicEvent>();
     } else {
-      cuda = std::make_unique<cu::CudaEvent>();
+      cuda = std::make_unique<cu::CopyableCudaEvent>();
     }
   }
 };
@@ -225,7 +264,7 @@ void Event::wait() {
     assert(value() == 1);
     event->cuda->wait();
   } else {
-    event->shared->wait(value());
+    event->atomic->wait(value());
   }
 }
 
@@ -236,7 +275,7 @@ void Event::wait(Stream s) {
     assert(value() == 1);
     event->cuda->wait(s);
   } else {
-    event->shared->wait(s, value());
+    event->atomic->wait(s, value());
   }
 }
 
@@ -247,7 +286,7 @@ void Event::signal(Stream s) {
     assert(value() == 1);
     event->cuda->record(s);
   } else {
-    event->shared->signal(s, value());
+    event->atomic->signal(s, value());
   }
 }
 
@@ -258,9 +297,9 @@ bool Event::is_signaled() const {
   }
   if (event->cuda) {
     assert(value() == 1);
-    return event->cuda->recorded() && event->cuda->completed();
+    return event->cuda->is_signaled();
   } else {
-    return event->shared->is_signaled(value());
+    return event->atomic->is_signaled(value());
   }
 }
 
diff --git a/mlx/backend/cuda/event.h b/mlx/backend/cuda/event.h
index 3ef9786f3..7a9ea66f6 100644
--- a/mlx/backend/cuda/event.h
+++ b/mlx/backend/cuda/event.h
@@ -3,49 +3,54 @@
 #pragma once
 
 #include "mlx/allocator.h"
+#include "mlx/backend/cuda/utils.h"
 #include "mlx/stream.h"
 
+#include <cuda_runtime.h>
+
 #include <cuda/atomic>
 #include <memory>
 
-#include <cuda_runtime.h>
-
 namespace mlx::core::cu {
 
-class CudaEventHandle;
+// RAII-managed move-only wrapper of cudaEvent_t.
+struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
+  CudaEventHandle(int flags);
+  int flags;
+};
 
 // Wrapper of native cuda event. It can synchronize between GPU streams, or wait
 // on GPU stream in CPU stream, but can not wait on CPU stream.
 class CudaEvent {
  public:
-  CudaEvent();
+  explicit CudaEvent(int flags);
+  ~CudaEvent();
+
+  CudaEvent(CudaEvent&&) = default;
+  CudaEvent& operator=(CudaEvent&&) = default;
+
+  CudaEvent(const CudaEvent&) = delete;
+  CudaEvent& operator=(const CudaEvent&) = delete;
 
   void wait();
   void wait(cudaStream_t stream);
-  void wait(Stream s);
   void record(cudaStream_t stream);
-  void record(Stream s);
 
   // Return whether the recorded kernels have completed. Note that this method
   // returns true if record() has not been called.
   bool completed() const;
 
-  bool recorded() const {
-    return recorded_;
-  }
-
 private:
-  bool recorded_{false};
-  std::shared_ptr<CudaEventHandle> event_;
+  CudaEventHandle event_;
 };
 
 // Event that can synchronize between CPU and GPU. It is much slower than
 // CudaEvent so the latter should always be preferred when possible.
-class SharedEvent {
+class AtomicEvent {
  public:
   using Atomic = cuda::atomic<uint64_t>;
 
-  SharedEvent();
+  AtomicEvent();
 
   void wait(uint64_t value);
   void wait(cudaStream_t stream, uint64_t value);
@@ -57,7 +62,11 @@ class SharedEvent {
   uint64_t value() const;
 
  private:
-  std::shared_ptr<Buffer> buf_;
+  Atomic* atomic() const {
+    return static_cast<Atomic*>(buf_->raw_ptr());
+  }
+
+  std::shared_ptr<Buffer> buf_;
 };
 
 } // namespace mlx::core::cu
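event.h now derives CudaEventHandle from CudaHandle, which is defined in mlx/backend/cuda/utils.h and does not appear in this diff. For readers following along, a move-only RAII wrapper of that shape might look like the sketch below; the template signature and member name are assumptions based on how event.cu uses handle_, not the actual utils.h definition:

```cpp
#include <cuda_runtime.h>

#include <utility>

// Sketch only; the real CudaHandle in mlx/backend/cuda/utils.h may differ.
template <typename T, cudaError_t (*Destroy)(T)>
struct CudaHandle {
  CudaHandle(T handle = nullptr) : handle_(handle) {}

  ~CudaHandle() {
    if (handle_ != nullptr) {
      Destroy(handle_);
    }
  }

  // Move-only: transferring ownership nulls out the source.
  CudaHandle(CudaHandle&& other) noexcept : handle_(other.handle_) {
    other.handle_ = nullptr;
  }
  CudaHandle& operator=(CudaHandle&& other) noexcept {
    std::swap(handle_, other.handle_);
    return *this;
  }
  CudaHandle(const CudaHandle&) = delete;
  CudaHandle& operator=(const CudaHandle&) = delete;

  // Implicit conversion lets callers pass the wrapper straight to CUDA
  // APIs, e.g. cudaEventSynchronize(event_).
  operator T() const {
    return handle_;
  }

 protected:
  T handle_;
};

// Matching the usage in event.h:
//   struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> { ... };
```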
diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp
index f399c4ebb..9121c7f4e 100644
--- a/mlx/backend/cuda/fence.cpp
+++ b/mlx/backend/cuda/fence.cpp
@@ -7,7 +7,7 @@ namespace mlx::core {
 
 struct FenceImpl {
   uint32_t count;
-  cu::SharedEvent event;
+  cu::AtomicEvent event;
 };
 
 Fence::Fence(Stream s) {
diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp
index ce211367c..349660ed4 100644
--- a/mlx/backend/cuda/worker.cpp
+++ b/mlx/backend/cuda/worker.cpp
@@ -7,6 +7,7 @@ namespace mlx::core::cu {
 
 Worker::Worker()
     : signal_stream_(device(mlx::core::Device::gpu)),
+      signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
       worker_(&Worker::thread_fn, this) {}
 
 Worker::~Worker() {
diff --git a/mlx/backend/cuda/worker.h b/mlx/backend/cuda/worker.h
index df6647e2b..7f1004f4d 100644
--- a/mlx/backend/cuda/worker.h
+++ b/mlx/backend/cuda/worker.h
@@ -3,7 +3,6 @@
 #pragma once
 
 #include "mlx/backend/cuda/event.h"
-#include "mlx/backend/cuda/utils.h"
 
 #include <condition_variable>
 #include <functional>
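Two closing notes. FenceImpl stays on AtomicEvent (rather than the pooled CudaEvent) because a fence signals a monotonically increasing count, which a binary recorded/not-recorded cudaEvent_t cannot represent; a hypothetical round trip using the AtomicEvent API from this patch:

```cpp
#include "mlx/backend/cuda/event.h"

using namespace mlx::core;

// Illustrative only: fence_roundtrip is not part of the patch. It mirrors
// the count/event pairing in FenceImpl.
void fence_roundtrip(Stream s, cu::AtomicEvent& event, uint32_t& count) {
  event.signal(s, ++count);  // update: publish the next value on stream s
  event.wait(count);         // wait: block until that value is visible
}
```

The worker.cpp hunk follows from CudaEvent losing its default constructor: signal_event_ must now be constructed with explicit flags. cudaEventDisableTiming skips timestamp bookkeeping, and cudaEventBlockingSync makes host-side waits block instead of busy-spinning, which suits a background worker thread.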