From 5685ceb3c79618fcda983b2f3657bf9528c64220 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Wed, 4 Jun 2025 08:48:40 +0900
Subject: [PATCH] Avoid invoking allocator::malloc when creating CUDA event
 (#2232)

---
 mlx/backend/cuda/allocator.cpp | 47 +++++++++++++++++++++++++----------------------
 mlx/backend/cuda/allocator.h   |  5 +++--
 mlx/backend/cuda/event.cu      |  9 +++++----
 3 files changed, 33 insertions(+), 28 deletions(-)

diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 86af3a774..00f78fd4f 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -18,7 +18,10 @@ CudaAllocator::CudaAllocator()
     : buffer_cache_(
           getpagesize(),
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) { cuda_free(buf); }) {
+          [this](CudaBuffer* buf) {
+            cuda_free(buf->data);
+            delete buf;
+          }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -70,7 +73,8 @@ void CudaAllocator::free(Buffer buffer) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
     lock.unlock();
-    cuda_free(buf);
+    cuda_free(buf->data);
+    delete buf;
   }
 }
 
@@ -87,6 +91,25 @@ void CudaAllocator::register_this_thread() {
   allowed_threads_.insert(std::this_thread::get_id());
 }
 
+void CudaAllocator::cuda_free(void* buf) {
+  // If cuda_free() is called from a unregistered thread, reschedule the call to
+  // worker.
+  {
+    std::lock_guard lock(worker_mutex_);
+    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+      if (!worker_) {
+        worker_.reset(new Worker);
+      }
+      worker_->add_task([this, buf]() { this->cuda_free(buf); });
+      worker_->end_batch();
+      worker_->commit();
+      return;
+    }
+  }
+
+  cudaFree(buf);
+}
+
 size_t CudaAllocator::get_active_memory() const {
   return active_memory_;
 }
@@ -125,26 +148,6 @@ void CudaAllocator::clear_cache() {
   buffer_cache_.clear();
 }
 
-void CudaAllocator::cuda_free(CudaBuffer* buf) {
-  // If cuda_free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
-
-  cudaFree(buf->data);
-  delete buf;
-}
-
 CudaAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of CudaAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h
index fe3755121..e268c6334 100644
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -34,6 +34,9 @@ class CudaAllocator : public allocator::Allocator {
   // buffers there would result in dead lock.
   void register_this_thread();
 
+  // Call cudaFree in the safe thread.
+  void cuda_free(void* buf);
+
   size_t get_active_memory() const;
   size_t get_peak_memory() const;
   void reset_peak_memory();
@@ -47,8 +50,6 @@ class CudaAllocator : public allocator::Allocator {
   CudaAllocator();
   friend CudaAllocator& allocator();
 
-  void cuda_free(CudaBuffer* buf);
-
   std::mutex worker_mutex_;
   std::unique_ptr<Worker> worker_;
   std::set<std::thread::id> allowed_threads_;
diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu
index a487f45b4..f462720a9 100644
--- a/mlx/backend/cuda/event.cu
+++ b/mlx/backend/cuda/event.cu
@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/event.h"
 #include "mlx/backend/cuda/utils.h"
@@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
 
 SharedEvent::SharedEvent() {
   // Allocate cuda::atomic on managed memory.
-  allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
-  Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
+  Atomic* ac;
+  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
   new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
+  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
     ptr->~Atomic();
-    allocator::free(buffer);
+    allocator().cuda_free(ptr);
   });
 }