diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 767170848..27e702743 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -68,8 +68,8 @@ Device::~Device() {
 
 void Device::make_current() {
   // We need to set/get current CUDA device very frequently, cache it to reduce
-  // actual calls of CUDA APIs. This function assumes single-thread in host.
-  static int current = 0;
+  // actual calls of CUDA APIs.
+  static thread_local int current = 0;
   if (current != device_) {
     CHECK_CUDA_ERROR(cudaSetDevice(device_));
     current = device_;
@@ -196,6 +196,7 @@ CommandEncoder::CommandEncoder(Device& d)
     : device_(d),
       stream_(d),
       graph_(d),
+      worker_(d),
       graph_cache_("MLX_CUDA_GRAPH_CACHE_SIZE", /* default_capacity */ 400) {}
 
 void CommandEncoder::add_completed_handler(std::function<void()> task) {
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h
index 3526de947..93667e736 100644
--- a/mlx/backend/cuda/device.h
+++ b/mlx/backend/cuda/device.h
@@ -140,7 +140,7 @@ class Device {
   Device(const Device&) = delete;
   Device& operator=(const Device&) = delete;
 
-  // Make this device the current cuda device, required by some cuda calls.
+  // Make this device the current cuda device; this method is thread-safe.
   void make_current();
 
   CommandEncoder& get_command_encoder(Stream s);
diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp
index b91fafa22..379d65423 100644
--- a/mlx/backend/cuda/eval.cpp
+++ b/mlx/backend/cuda/eval.cpp
@@ -15,9 +15,10 @@ bool is_available() {
 }
 
 void new_stream(Stream s) {
-  // Force initalization of CUDA by creating an event, so the CUDA runtime and
-  // our CUDA event pool get destroyed last.
-  cu::CudaEvent(cudaEventDefault);
+  // Force initialization of CUDA, so the CUDA runtime gets destroyed last.
+  cudaFree(nullptr);
+  // Make sure the CUDA event pool gets destroyed after device and stream.
+  cu::CudaEvent::init_pool();
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
 }
diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu
index 2d9c96604..55b8d1e64 100644
--- a/mlx/backend/cuda/event.cu
+++ b/mlx/backend/cuda/event.cu
@@ -22,11 +22,15 @@ namespace cu {
 namespace {
 
 // Manage cached cudaEvent_t objects.
-struct CudaEventPool {
-  static CudaEventHandle create(int flags) {
-    auto& cache = cache_for(flags);
+class CudaEventPool {
+ public:
+  CudaEventHandle create(Device& d, int flags) {
+    if (!on_creation_thread()) {
+      return CudaEventHandle(d, flags);
+    }
+    auto& cache = cache_for(d, flags);
     if (cache.empty()) {
-      return CudaEventHandle(flags);
+      return CudaEventHandle(d, flags);
     } else {
       CudaEventHandle ret = std::move(cache.back());
       cache.pop_back();
@@ -34,54 +38,89 @@
     }
   }
 
-  static void release(CudaEventHandle event) {
-    cache_for(event.flags).push_back(std::move(event));
+  void release(CudaEventHandle event) {
+    if (!on_creation_thread()) {
+      // The event is destroyed directly instead of being moved to the cache.
+      return;
+    }
+    cache_for(event.device, event.flags).push_back(std::move(event));
   }
 
-  static std::vector<CudaEventHandle>& cache_for(int flags) {
-    static std::map<int, std::vector<CudaEventHandle>> cache;
-    return cache[flags];
+ private:
+  std::vector<CudaEventHandle>& cache_for(Device& d, int flags) {
+    return cache_[d.cuda_device()][flags];
   }
+
+  bool on_creation_thread() {
+    return std::this_thread::get_id() == thread_id_;
+  }
+
+  // The CudaEvent may be created and destroyed on different threads (for
+  // example when waiting on GPU work in a CPU stream). We don't want to
+  // make the cache thread-safe as that adds overhead, so we just skip the
+  // cache when using events in worker threads.
+  std::thread::id thread_id_{std::this_thread::get_id()};
+
+  // {device: {flags: [events]}}
+  std::map<int, std::map<int, std::vector<CudaEventHandle>>> cache_;
 };
 
+CudaEventPool& cuda_event_pool() {
+  static CudaEventPool pool;
+  return pool;
+}
+
 } // namespace
 
-CudaEventHandle::CudaEventHandle(int flags) : flags(flags) {
+CudaEventHandle::CudaEventHandle(Device& d, int flags)
+    : device(d), flags(flags) {
+  device.make_current();
   CHECK_CUDA_ERROR(cudaEventCreateWithFlags(&handle_, flags));
   assert(handle_ != nullptr);
 }
 
-CudaEvent::CudaEvent(int flags) : event_(CudaEventPool::create(flags)) {}
+CudaEvent::CudaEvent(Device& d, int flags)
+    : event_(cuda_event_pool().create(d, flags)) {}
 
 CudaEvent::~CudaEvent() {
-  CudaEventPool::release(std::move(event_));
+  cuda_event_pool().release(std::move(event_));
 }
 
 void CudaEvent::wait() {
   nvtx3::scoped_range r("cu::CudaEvent::wait");
+  event_.device.make_current();
   cudaEventSynchronize(event_);
 }
 
 void CudaEvent::wait(cudaStream_t stream) {
+  event_.device.make_current();
   cudaStreamWaitEvent(stream, event_);
 }
 
 void CudaEvent::record(cudaStream_t stream) {
+  event_.device.make_current();
   cudaEventRecord(event_, stream);
 }
 
 bool CudaEvent::completed() const {
+  // Note: cudaEventQuery can be safely called from any device.
   return cudaEventQuery(event_) == cudaSuccess;
 }
 
+// static
+void CudaEvent::init_pool() {
+  cuda_event_pool();
+}
+
 // Wraps CudaEvent with a few features:
 // 1. The class can be copied.
 // 2. Make wait/record work with CPU streams.
 // 3. Add checks for waiting on un-recorded event.
 class CopyableCudaEvent {
  public:
-  CopyableCudaEvent()
+  explicit CopyableCudaEvent(Device& d)
       : event_(std::make_shared<CudaEvent>(
+            d,
             cudaEventDisableTiming | cudaEventBlockingSync)) {}
 
   void wait() {
@@ -245,7 +284,7 @@ struct EventImpl {
       nvtx3::mark("Using slow AtomicEvent");
       atomic = std::make_unique<AtomicEvent>();
     } else {
-      cuda = std::make_unique<CopyableCudaEvent>();
+      cuda = std::make_unique<CopyableCudaEvent>(cu::device(s.device));
     }
   }
 };
diff --git a/mlx/backend/cuda/event.h b/mlx/backend/cuda/event.h
index 7a9ea66f6..342e6ae3f 100644
--- a/mlx/backend/cuda/event.h
+++ b/mlx/backend/cuda/event.h
@@ -13,9 +13,12 @@
 
 namespace mlx::core::cu {
 
+class Device;
+
 // RAII-managed move-only wrapper of cudaEvent_t.
 struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
-  CudaEventHandle(int flags);
+  CudaEventHandle(Device& d, int flags);
+  Device& device;
   int flags;
 };
 
@@ -23,7 +26,7 @@ struct CudaEventHandle : public CudaHandle<cudaEvent_t, cudaEventDestroy> {
 // on GPU stream in CPU stream, but can not wait on CPU stream.
 class CudaEvent {
  public:
-  explicit CudaEvent(int flags);
+  CudaEvent(Device& d, int flags);
   ~CudaEvent();
 
   CudaEvent(CudaEvent&&) = default;
@@ -40,6 +43,9 @@ class CudaEvent {
   // returns true if record() has not been called.
   bool completed() const;
 
+  // Internal: make sure the event pool is initialized.
+  static void init_pool();
+
  private:
   CudaEventHandle event_;
 };
diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp
index 349660ed4..c468b8501 100644
--- a/mlx/backend/cuda/worker.cpp
+++ b/mlx/backend/cuda/worker.cpp
@@ -5,9 +5,9 @@
 
 namespace mlx::core::cu {
 
-Worker::Worker()
-    : signal_stream_(device(mlx::core::Device::gpu)),
-      signal_event_(cudaEventDisableTiming | cudaEventBlockingSync),
+Worker::Worker(Device& d)
+    : signal_stream_(d),
+      signal_event_(d, cudaEventDisableTiming | cudaEventBlockingSync),
       worker_(&Worker::thread_fn, this) {}
 
 Worker::~Worker() {
diff --git a/mlx/backend/cuda/worker.h b/mlx/backend/cuda/worker.h
index 7f1004f4d..8f05e7b97 100644
--- a/mlx/backend/cuda/worker.h
+++ b/mlx/backend/cuda/worker.h
@@ -15,7 +15,7 @@ namespace mlx::core::cu {
 // Run tasks in worker thread, synchronized with cuda stream.
 class Worker {
  public:
-  Worker();
+  explicit Worker(Device& d);
   ~Worker();
 
   Worker(const Worker&) = delete;
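
The patch leans on two lock-free thread-safety patterns: a thread_local cache of the current device in Device::make_current(), and an event pool in CudaEventPool that only its creation thread may touch, with off-thread callers bypassing the cache entirely. Below is a minimal standalone sketch of both, separate from the patch itself; the names set_current_device and EventCache are illustrative only, not MLX APIs, and the real CUDA calls are stubbed out in comments.

#include <thread>
#include <vector>

// Pattern 1: per-thread cache of the current device. Each host thread
// tracks its own "current" id, so the redundant cudaSetDevice call is
// skipped without any locking, matching Device::make_current() above.
void set_current_device(int device) {
  static thread_local int current = 0; // CUDA's default device per thread
  if (current != device) {
    // cudaSetDevice(device) would go here in the real code.
    current = device;
  }
}

// Pattern 2: a cache that is only used from its creation thread. Calls
// from any other thread fall back to plain create/destroy, so the hot
// single-threaded path stays mutex-free, matching CudaEventPool above.
class EventCache {
 public:
  int create() {
    if (!on_creation_thread() || cache_.empty()) {
      return next_++; // stand-in for constructing a fresh CudaEventHandle
    }
    int handle = cache_.back();
    cache_.pop_back();
    return handle;
  }

  void release(int handle) {
    if (!on_creation_thread()) {
      return; // off-thread: the handle is destroyed directly, never cached
    }
    cache_.push_back(handle);
  }

 private:
  bool on_creation_thread() const {
    return std::this_thread::get_id() == thread_id_;
  }

  std::thread::id thread_id_{std::this_thread::get_id()};
  std::vector<int> cache_;
  int next_{0};
};

The trade-off in pattern 2 is deliberate: events released on worker threads lose the benefit of reuse, but correctness is preserved because an uncached handle is simply destroyed by its own destructor, and the common path pays no synchronization cost.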