mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Avoid invoking allocator::malloc when creating CUDA event (#2232)
This commit is contained in:
parent
0408ba0a76
commit
5685ceb3c7
@ -18,7 +18,10 @@ CudaAllocator::CudaAllocator()
|
|||||||
: buffer_cache_(
|
: buffer_cache_(
|
||||||
getpagesize(),
|
getpagesize(),
|
||||||
[](CudaBuffer* buf) { return buf->size; },
|
[](CudaBuffer* buf) { return buf->size; },
|
||||||
[this](CudaBuffer* buf) { cuda_free(buf); }) {
|
[this](CudaBuffer* buf) {
|
||||||
|
cuda_free(buf->data);
|
||||||
|
delete buf;
|
||||||
|
}) {
|
||||||
// TODO: Set memory limit for multi-device.
|
// TODO: Set memory limit for multi-device.
|
||||||
size_t free, total;
|
size_t free, total;
|
||||||
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
|
||||||
@ -70,7 +73,8 @@ void CudaAllocator::free(Buffer buffer) {
|
|||||||
buffer_cache_.recycle_to_cache(buf);
|
buffer_cache_.recycle_to_cache(buf);
|
||||||
} else {
|
} else {
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
cuda_free(buf);
|
cuda_free(buf->data);
|
||||||
|
delete buf;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,6 +91,25 @@ void CudaAllocator::register_this_thread() {
|
|||||||
allowed_threads_.insert(std::this_thread::get_id());
|
allowed_threads_.insert(std::this_thread::get_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Releases |buf| with cudaFree(). When the calling thread is not present in
// allowed_threads_, nothing is freed on this thread: a task that calls
// cuda_free(buf) again is queued on worker_ (created on first use) and this
// call returns immediately.
void CudaAllocator::cuda_free(void* buf) {
|
||||||
|
// If cuda_free() is called from an unregistered thread, reschedule the call to
|
||||||
|
// the worker.
|
||||||
|
{
|
||||||
|
// The lock lives only for this inner scope, so worker_mutex_ is released
// before cudaFree() is reached below.
std::lock_guard lock(worker_mutex_);
|
||||||
|
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
||||||
|
// Create the worker lazily, the first time a reschedule is needed.
if (!worker_) {
|
||||||
|
worker_.reset(new Worker);
|
||||||
|
}
|
||||||
|
// NOTE(review): the task re-enters cuda_free(); presumably the worker's
// thread is registered in allowed_threads_, otherwise the call would be
// rescheduled again. Verify against Worker.
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
||||||
|
worker_->end_batch();
|
||||||
|
worker_->commit();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE(review): the status returned by cudaFree() is not checked here.
// Confirm this is intentional.
cudaFree(buf);
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the current value of active_memory_.
size_t CudaAllocator::get_active_memory() const {
|
size_t CudaAllocator::get_active_memory() const {
|
||||||
return active_memory_;
|
return active_memory_;
|
||||||
}
|
}
|
||||||
@ -125,26 +148,6 @@ void CudaAllocator::clear_cache() {
|
|||||||
buffer_cache_.clear();
|
buffer_cache_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-commit overload (shown on the removed side of this diff): releases
// buf->data with cudaFree() and then deletes the CudaBuffer itself. When the
// calling thread is not present in allowed_threads_, the work is queued on
// worker_ instead and this call returns immediately.
void CudaAllocator::cuda_free(CudaBuffer* buf) {
|
|
||||||
// If cuda_free() is called from an unregistered thread, reschedule the call to
|
|
||||||
// the worker.
|
|
||||||
{
|
|
||||||
// The lock lives only for this inner scope, so worker_mutex_ is released
// before cudaFree() is reached below.
std::lock_guard lock(worker_mutex_);
|
|
||||||
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
|
|
||||||
// Create the worker lazily, the first time a reschedule is needed.
if (!worker_) {
|
|
||||||
worker_.reset(new Worker);
|
|
||||||
}
|
|
||||||
worker_->add_task([this, buf]() { this->cuda_free(buf); });
|
|
||||||
worker_->end_batch();
|
|
||||||
worker_->commit();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Free the device allocation first, then the bookkeeping object that
// pointed to it.
cudaFree(buf->data);
|
|
||||||
delete buf;
|
|
||||||
}
|
|
||||||
|
|
||||||
CudaAllocator& allocator() {
|
CudaAllocator& allocator() {
|
||||||
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
// By creating the |allocator_| on heap, the destructor of CudaAllocator
|
||||||
// will not be called on exit and buffers in the cache will be leaked. This
|
// will not be called on exit and buffers in the cache will be leaked. This
|
||||||
|
@ -34,6 +34,9 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
// buffers there would result in deadlock.
|
// buffers there would result in deadlock.
|
||||||
void register_this_thread();
|
void register_this_thread();
|
||||||
|
|
||||||
|
// Call cudaFree from a registered thread; calls made from other threads are rescheduled to the worker.
|
||||||
|
void cuda_free(void* buf);
|
||||||
|
|
||||||
size_t get_active_memory() const;
|
size_t get_active_memory() const;
|
||||||
size_t get_peak_memory() const;
|
size_t get_peak_memory() const;
|
||||||
void reset_peak_memory();
|
void reset_peak_memory();
|
||||||
@ -47,8 +50,6 @@ class CudaAllocator : public allocator::Allocator {
|
|||||||
CudaAllocator();
|
CudaAllocator();
|
||||||
friend CudaAllocator& allocator();
|
friend CudaAllocator& allocator();
|
||||||
|
|
||||||
void cuda_free(CudaBuffer* buf);
|
|
||||||
|
|
||||||
std::mutex worker_mutex_;
|
std::mutex worker_mutex_;
|
||||||
std::unique_ptr<Worker> worker_;
|
std::unique_ptr<Worker> worker_;
|
||||||
std::set<std::thread::id> allowed_threads_;
|
std::set<std::thread::id> allowed_threads_;
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/event.h"
|
#include "mlx/backend/cuda/event.h"
|
||||||
#include "mlx/backend/cuda/utils.h"
|
#include "mlx/backend/cuda/utils.h"
|
||||||
@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
|
|||||||
|
|
||||||
// Builds the event's counter: an Atomic constructed with value 0 and owned by
// ac_ through a std::shared_ptr with a custom deleter.
// In this side-by-side view each row lists the pre-commit line first and the
// post-commit line second: before the change the storage came from
// allocator::malloc() and was returned with allocator::free(); after it, the
// storage comes from cudaMallocManaged() and is released through
// allocator().cuda_free().
SharedEvent::SharedEvent() {
|
SharedEvent::SharedEvent() {
|
||||||
// Allocate cuda::atomic on managed memory.
|
// Allocate cuda::atomic on managed memory.
|
||||||
allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
|
Atomic* ac;
|
||||||
Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
|
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
|
||||||
// Placement-new: construct the Atomic inside the storage obtained above.
new (ac) Atomic(0);
|
new (ac) Atomic(0);
|
||||||
// The deleter mirrors the construction: run the destructor explicitly, then
// release the storage.
ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
|
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
|
||||||
ptr->~Atomic();
|
ptr->~Atomic();
|
||||||
allocator::free(buffer);
|
allocator().cuda_free(ptr);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user