Avoid invoking allocator::malloc when creating CUDA event (#2232)

Cheng 2025-06-04 08:48:40 +09:00 committed by GitHub
parent 0408ba0a76
commit 5685ceb3c7
3 changed files with 33 additions and 28 deletions

mlx/backend/cuda/allocator.cpp

@@ -18,7 +18,10 @@ CudaAllocator::CudaAllocator()
     : buffer_cache_(
           getpagesize(),
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) { cuda_free(buf); }) {
+          [this](CudaBuffer* buf) {
+            cuda_free(buf->data);
+            delete buf;
+          }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -70,7 +73,8 @@ void CudaAllocator::free(Buffer buffer) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
     lock.unlock();
-    cuda_free(buf);
+    cuda_free(buf->data);
+    delete buf;
   }
 }
@@ -87,6 +91,25 @@ void CudaAllocator::register_this_thread() {
   allowed_threads_.insert(std::this_thread::get_id());
 }
 
+void CudaAllocator::cuda_free(void* buf) {
+  // If cuda_free() is called from an unregistered thread, reschedule the
+  // call to the worker.
+  {
+    std::lock_guard lock(worker_mutex_);
+    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+      if (!worker_) {
+        worker_.reset(new Worker);
+      }
+      worker_->add_task([this, buf]() { this->cuda_free(buf); });
+      worker_->end_batch();
+      worker_->commit();
+      return;
+    }
+  }
+  cudaFree(buf);
+}
+
 size_t CudaAllocator::get_active_memory() const {
   return active_memory_;
 }
@@ -125,26 +148,6 @@ void CudaAllocator::clear_cache() {
   buffer_cache_.clear();
 }
 
-void CudaAllocator::cuda_free(CudaBuffer* buf) {
-  // If cuda_free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
-  cudaFree(buf->data);
-  delete buf;
-}
-
 CudaAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of CudaAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
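
Below is a minimal, self-contained sketch of the deferral pattern the new cuda_free(void*) relies on. The MLX Worker machinery (add_task/end_batch/commit) is replaced by a hypothetical DeferredRunner with a plain task queue, so names here are illustrative, not MLX API. The point is only the control flow: a registered thread runs the task inline, while any other thread lazily starts a worker and hands the task off, just as cuda_free re-dispatches itself onto a safe thread before calling cudaFree.

#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <set>
#include <thread>

class DeferredRunner {
 public:
  ~DeferredRunner() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      done_ = true;
    }
    cv_.notify_one();
    if (worker_.joinable()) worker_.join();
  }

  void register_this_thread() {
    std::lock_guard<std::mutex> lock(mutex_);
    allowed_.insert(std::this_thread::get_id());
  }

  // Run |task| inline if called from a registered thread; otherwise queue it
  // for the worker thread (mirroring CudaAllocator::cuda_free()).
  void run(std::function<void()> task) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (allowed_.count(std::this_thread::get_id()) == 0) {
      if (!worker_.joinable()) {
        worker_ = std::thread([this] { worker_loop(); });  // lazy start
      }
      tasks_.push(std::move(task));
      cv_.notify_one();
      return;
    }
    lock.unlock();
    task();
  }

 private:
  void worker_loop() {
    register_this_thread();  // the worker itself counts as a safe thread
    std::unique_lock<std::mutex> lock(mutex_);
    while (true) {
      cv_.wait(lock, [this] { return done_ || !tasks_.empty(); });
      while (!tasks_.empty()) {
        auto task = std::move(tasks_.front());
        tasks_.pop();
        lock.unlock();
        task();  // run without holding the lock
        lock.lock();
      }
      if (done_) return;
    }
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  std::set<std::thread::id> allowed_;
  std::thread worker_;
  bool done_ = false;
};

int main() {
  DeferredRunner runner;
  runner.register_this_thread();
  runner.run([] { std::puts("ran inline on a registered thread"); });
  std::thread t([&] {
    runner.run([] { std::puts("deferred to the worker thread"); });
  });
  t.join();
  return 0;
}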

mlx/backend/cuda/allocator.h

@@ -34,6 +34,9 @@ class CudaAllocator : public allocator::Allocator {
   // buffers there would result in dead lock.
   void register_this_thread();
 
+  // Call cudaFree on a safe thread.
+  void cuda_free(void* buf);
+
   size_t get_active_memory() const;
   size_t get_peak_memory() const;
   void reset_peak_memory();
@@ -47,8 +50,6 @@ class CudaAllocator : public allocator::Allocator {
   CudaAllocator();
   friend CudaAllocator& allocator();
 
-  void cuda_free(CudaBuffer* buf);
-
   std::mutex worker_mutex_;
   std::unique_ptr<Worker> worker_;
   std::set<std::thread::id> allowed_threads_;

mlx/backend/cuda/event.cu

@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/event.h"
 #include "mlx/backend/cuda/utils.h"
@@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
 
 SharedEvent::SharedEvent() {
   // Allocate cuda::atomic on managed memory.
-  allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
-  Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
+  Atomic* ac;
+  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
   new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
+  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
     ptr->~Atomic();
-    allocator::free(buffer);
+    allocator().cuda_free(ptr);
   });
 }
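
For reference, a minimal sketch of the lifetime pattern the new SharedEvent constructor uses: construct an object on CUDA managed memory with placement new, then let a std::shared_ptr deleter run the destructor before releasing the allocation. This sketch assumes nvcc with libcu++ available, and it frees with a plain cudaFree; the real deleter instead routes through allocator().cuda_free(ptr) so the free can be rescheduled onto a safe thread.

#include <cstdint>
#include <cstdio>
#include <memory>
#include <new>

#include <cuda/atomic>
#include <cuda_runtime.h>

using Atomic = cuda::atomic<std::uint64_t, cuda::thread_scope_system>;

int main() {
  // cudaMallocManaged returns raw, uninitialized bytes visible to both the
  // host and the device.
  Atomic* ac = nullptr;
  if (cudaMallocManaged(&ac, sizeof(Atomic)) != cudaSuccess) {
    std::fprintf(stderr, "cudaMallocManaged failed\n");
    return 1;
  }
  new (ac) Atomic(0);  // construct in place, as SharedEvent does

  // The deleter pairs the placement new with an explicit destructor call,
  // then releases the managed allocation.
  std::shared_ptr<Atomic> ac_(ac, [](Atomic* ptr) {
    ptr->~Atomic();
    cudaFree(ptr);
  });

  ac_->store(1);
  std::printf("value = %llu\n", (unsigned long long)ac_->load());
  return 0;  // deleter runs here: ~Atomic(), then cudaFree
}

Pairing placement new with an explicit destructor call in the deleter is what lets the code manage object lifetime separately from the raw managed allocation, which is exactly why the constructor no longer needs allocator::malloc at all.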