[CUDA] Simplify allocator (#2392)

* simplify allocator and fix race with small pool

* Don't use shared event in worker

* use cuda buffer in small pool

* comment

* comment
Awni Hannun 2025-07-22 08:24:01 -07:00 committed by GitHub
parent 74eccbf3fa
commit 1e496ddb82
9 changed files with 100 additions and 162 deletions

mlx/backend/cuda/allocator.cpp

@@ -2,7 +2,6 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
-#include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"
 
 #include <cuda_runtime.h>
@@ -25,52 +24,58 @@ constexpr int small_block_size = 8;
 constexpr int small_pool_size = 4 * page_size;
 
 SmallSizePool::SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
-  end_ = reinterpret_cast<void*>(
-      reinterpret_cast<char*>(buffer_) + small_pool_size);
-  next_free_ = reinterpret_cast<Block*>(buffer_);
   auto num_blocks = small_pool_size / small_block_size;
+  buffer_ = new Block[num_blocks];
+  next_free_ = buffer_;
+  CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
+  CHECK_CUDA_ERROR(
+      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
 
   auto curr = next_free_;
-  for (size_t i = 0; i < num_blocks - 1; ++i) {
-    curr->next = reinterpret_cast<Block*>(
-        reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
+  for (size_t i = 1; i < num_blocks; ++i) {
+    curr->next = buffer_ + i;
     curr = curr->next;
   }
   curr->next = nullptr;
 }
 
 SmallSizePool::~SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaFree(buffer_));
+  CHECK_CUDA_ERROR(cudaFree(data_));
+  delete[] buffer_;
 }
 
-void* SmallSizePool::malloc() {
+CudaBuffer* SmallSizePool::malloc() {
   if (next_free_ == nullptr) {
     return nullptr;
   }
   Block* b = next_free_;
+  uint64_t i = next_free_ - buffer_;
   next_free_ = next_free_->next;
-  return static_cast<void*>(b);
+  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
+  b->buf.size = small_block_size;
+  return &b->buf;
 }
 
-void SmallSizePool::free(void* p) {
-  auto b = static_cast<Block*>(p);
+void SmallSizePool::free(CudaBuffer* buf) {
+  auto b = reinterpret_cast<Block*>(buf);
   b->next = next_free_;
   next_free_ = b;
 }
 
-bool SmallSizePool::in_pool(void* p) {
-  return (p >= buffer_) && (p < end_);
+bool SmallSizePool::in_pool(CudaBuffer* buf) {
+  constexpr int num_blocks = (small_pool_size / small_block_size);
+  auto b = reinterpret_cast<Block*>(buf);
+  int64_t block_num = b - buffer_;
+  return block_num >= 0 && block_num < num_blocks;
 }
 
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
           page_size,
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) {
-            cuda_free(buf->data);
-            delete buf;
-          }) {
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
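The new pool design splits bookkeeping from storage: buffer_ is a host-side array of union Block entries that threads the free list, while data_ is one cudaMallocManaged region whose i-th 8-byte slot backs the i-th block. Because a block index doubles as a data-slot index, malloc() can materialize a CudaBuffer in place, and free()/in_pool() reduce to pointer arithmetic. A minimal, self-contained sketch of the same free-list technique, with plain heap memory standing in for cudaMallocManaged and illustrative names (FixedPool, Buf) that are not mlx API:

// Sketch of a fixed-size free-list pool; std::malloc replaces
// cudaMallocManaged so the example runs anywhere.
#include <cassert>
#include <cstddef>
#include <cstdlib>

struct Buf {
  void* data;
  size_t size;
};

class FixedPool {
  // A slot is either a free-list link or a live Buf, never both at once,
  // so a union adds no per-slot overhead.
  union Block {
    Block* next;
    Buf buf;
  };
  static constexpr size_t block_size = 8;
  static constexpr size_t num_blocks = 64;

  Block* blocks_;  // host-side bookkeeping
  char* data_;     // the storage actually handed out to callers
  Block* next_free_;

 public:
  FixedPool() {
    blocks_ = new Block[num_blocks];
    data_ = static_cast<char*>(std::malloc(num_blocks * block_size));
    next_free_ = blocks_;
    for (size_t i = 0; i + 1 < num_blocks; ++i) {
      blocks_[i].next = blocks_ + i + 1;  // thread the free list
    }
    blocks_[num_blocks - 1].next = nullptr;
  }
  ~FixedPool() {
    std::free(data_);
    delete[] blocks_;
  }

  Buf* malloc() {
    if (!next_free_) {
      return nullptr;  // exhausted: caller falls back to the slow allocator
    }
    Block* b = next_free_;
    size_t i = b - blocks_;  // block index doubles as data-slot index
    next_free_ = b->next;
    b->buf.data = data_ + i * block_size;
    b->buf.size = block_size;
    return &b->buf;
  }

  void free(Buf* buf) {
    auto b = reinterpret_cast<Block*>(buf);  // Buf is the union's first member
    b->next = next_free_;
    next_free_ = b;
  }

  bool in_pool(Buf* buf) const {
    auto d = reinterpret_cast<Block*>(buf) - blocks_;
    return d >= 0 && static_cast<size_t>(d) < num_blocks;
  }
};

int main() {
  FixedPool pool;
  Buf* a = pool.malloc();
  assert(a != nullptr && pool.in_pool(a));
  pool.free(a);
  return 0;
}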
@@ -92,28 +97,26 @@ Buffer CudaAllocator::malloc(size_t size) {
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
-    // If we have a lot of memory pressure or are over the maximum cache size,
-    // try to reclaim memory from the cache.
-    size_t mem_required = get_active_memory() + get_cache_memory() + size;
-    if (mem_required >= memory_limit_) {
-      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    // If we have a lot of memory pressure try to reclaim memory from the cache.
+    int64_t mem_to_free =
+        get_active_memory() + get_cache_memory() + size - memory_limit_;
+    if (mem_to_free > 0) {
+      buffer_cache_.release_cached_buffers(mem_to_free);
     }
-    lock.unlock();
-    buf = new CudaBuffer{nullptr, size};
 
     // Try the scalar pool first
     if (size <= small_block_size) {
-      buf->data = scalar_pool_.malloc();
+      buf = scalar_pool_.malloc();
     }
-    if (!buf->data) {
+    lock.unlock();
+    if (!buf) {
+      buf = new CudaBuffer{nullptr, size};
       cudaError_t err = cudaMallocManaged(&buf->data, size);
       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
         throw std::runtime_error(fmt::format(
             "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
       }
     }
     lock.lock();
   }
   active_memory_ += size;
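This hunk is the race fix named in the commit message: the old code released the lock and then called scalar_pool_.malloc(), so a concurrent free could splice the pool's free list while another thread was popping it. Now all pool traffic stays under mutex_, and the lock is dropped only around the expensive cudaMallocManaged call. The shape of the pattern, with hypothetical pool_malloc/driver_malloc helpers standing in for the real calls:

// Schematic of the locking pattern above (illustrative, not mlx API).
#include <cstddef>
#include <mutex>

struct Buf { void* data; std::size_t size; };

Buf* pool_malloc();               // mutates state guarded by the caller's mutex
Buf* driver_malloc(std::size_t);  // expensive, may synchronize; call unlocked

Buf* alloc(std::unique_lock<std::mutex>& lock, std::size_t size) {
  Buf* buf = nullptr;
  if (size <= 8) {
    buf = pool_malloc();  // fast path: free list touched only under the lock
  }
  lock.unlock();          // drop the lock just for the slow path
  if (!buf) {
    buf = driver_malloc(size);
  }
  lock.lock();            // re-acquire for the shared bookkeeping that follows
  return buf;
}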
@@ -123,7 +126,6 @@ Buffer CudaAllocator::malloc(size_t size) {
   if (get_cache_memory() > max_pool_size_) {
     buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
   }
-
   return Buffer{buf};
 }
@@ -138,9 +140,7 @@ void CudaAllocator::free(Buffer buffer) {
   if (get_cache_memory() < max_pool_size_) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
-    lock.unlock();
-    cuda_free(buf->data);
-    delete buf;
+    cuda_free(buf);
   }
 }
@@ -152,30 +152,13 @@ size_t CudaAllocator::size(Buffer buffer) const {
   return buf->size;
 }
 
-void CudaAllocator::register_this_thread() {
-  std::lock_guard lock(worker_mutex_);
-  allowed_threads_.insert(std::this_thread::get_id());
-}
-
-void CudaAllocator::cuda_free(void* buf) {
-  // If cuda_free() is called from an unregistered thread, reschedule the call
-  // to worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
+// This must be called with mutex_ acquired
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf);
+    cudaFree(buf->data);
+    delete buf;
   }
 }

mlx/backend/cuda/allocator.h

@@ -7,13 +7,10 @@
 #include <mutex>
 #include <set>
-#include <thread>
 #include <utility>
 
 namespace mlx::core::cu {
 
-class Worker;
-
 using allocator::Buffer;
 
 // Stores cuda-managed unified memory.
@@ -24,13 +21,14 @@ struct CudaBuffer {
 class SmallSizePool {
  private:
-  struct Block {
+  union Block {
     Block* next;
+    CudaBuffer buf;
   };
 
-  void* buffer_{nullptr};
+  Block* buffer_{nullptr};
+  void* data_{nullptr};
   Block* next_free_{nullptr};
-  void* end_{nullptr};
 
  public:
   SmallSizePool();
@@ -39,9 +37,9 @@ class SmallSizePool {
   SmallSizePool(const SmallSizePool&) = delete;
   SmallSizePool& operator=(const SmallSizePool&) = delete;
 
-  void* malloc();
-  void free(void* p);
-  bool in_pool(void* p);
+  CudaBuffer* malloc();
+  void free(CudaBuffer* buf);
+  bool in_pool(CudaBuffer* buf);
 };
 
 class CudaAllocator : public allocator::Allocator {
@@ -50,15 +48,6 @@ class CudaAllocator : public allocator::Allocator {
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;
 
-  // Register current thread as safe to free buffers.
-  // In cuda freeing a buffer implicitly synchronizes stream, and for threads
-  // that may be waited by gpu stream (for example cpu stream threads), freeing
-  // buffers there would result in dead lock.
-  void register_this_thread();
-
-  // Call cudaFree in the safe thread.
-  void cuda_free(void* buf);
-
   size_t get_active_memory() const;
   size_t get_peak_memory() const;
   void reset_peak_memory();
@@ -69,13 +58,11 @@ class CudaAllocator : public allocator::Allocator {
   void clear_cache();
 
  private:
+  void cuda_free(CudaBuffer* buf);
+
   CudaAllocator();
   friend CudaAllocator& allocator();
 
-  std::mutex worker_mutex_;
-  std::unique_ptr<Worker> worker_;
-  std::set<std::thread::id> allowed_threads_;
-
   std::mutex mutex_;
   size_t memory_limit_;
   size_t max_pool_size_;

mlx/backend/cuda/device.cpp

@@ -306,7 +306,6 @@ void CommandEncoder::commit() {
   }
 
   // Put completion handlers in a batch.
-  worker_.end_batch();
   worker_.commit(stream_);
 }
@@ -315,7 +314,6 @@ void CommandEncoder::synchronize() {
   auto p = std::make_shared<std::promise<void>>();
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
-  worker_.end_batch();
   commit();
   f.wait();
 }

mlx/backend/cuda/eval.cpp

@@ -19,8 +19,6 @@ void new_stream(Stream s) {
   cudaFree(nullptr);
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
-  // The main thread is safe to free buffers.
-  cu::allocator().register_this_thread();
 }
 
 void eval(array& arr) {

mlx/backend/cuda/event.cu

@@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
   event_signal(ac, value);
 }
 
+SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
+  return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
+}
+
 SharedEvent::SharedEvent() {
-  // Allocate cuda::atomic on managed memory.
-  Atomic* ac;
-  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
-  new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
-    ptr->~Atomic();
-    allocator().cuda_free(ptr);
-  });
+  buf_ = std::shared_ptr<Buffer>(
+      new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
+        allocator().free(*ptr);
+        delete ptr;
+      });
+  *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
 }
 
 void SharedEvent::wait(uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::wait");
-  event_wait(ac_.get(), value);
+  event_wait(to_atomic(buf_), value);
 }
 
 void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
-  event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
+  event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
 }
 
 void SharedEvent::wait(Stream s, uint64_t value) {
@@ -138,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
   auto& encoder = get_command_encoder(s);
   encoder.commit();
   wait(encoder.stream(), value);
-  encoder.add_completed_handler([ac = ac_]() {});
+  encoder.add_completed_handler([buf = buf_]() {});
 }
 
 void SharedEvent::signal(uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::signal");
-  event_signal(ac_.get(), value);
+  event_signal(to_atomic(buf_), value);
 }
 
 void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
-  event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
+  event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
 }
 
 void SharedEvent::signal(Stream s, uint64_t value) {
@@ -162,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
   auto& encoder = get_command_encoder(s);
   encoder.commit();
   signal(encoder.stream(), value);
-  encoder.add_completed_handler([ac = ac_]() {});
+  encoder.add_completed_handler([buf = buf_]() {});
 }
 
 bool SharedEvent::is_signaled(uint64_t value) const {
   nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
-  return ac_->load() >= value;
+  return to_atomic(buf_)->load() >= value;
 }
 
 uint64_t SharedEvent::value() const {
   nvtx3::scoped_range r("cu::SharedEvent::value");
-  return ac_->load();
+  return to_atomic(buf_)->load();
 }
 
 } // namespace cu
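Two details worth noting here. The completion handlers that capture buf by value exist only to keep the managed allocation alive until in-flight GPU work drains. And the constructor now zero-initializes through a plain uint64_t store instead of placement-newing an Atomic, which is sound only if SharedEvent::Atomic is a bare lock-free 64-bit cell, for example a cuda::atomic<uint64_t> as the kernel signatures above suggest. Guards one could add to pin that assumption down (illustrative, not part of the commit):

// Assumes SharedEvent::Atomic is cuda::atomic<uint64_t>; both checks hold
// for libcu++ on 64-bit platforms if so.
#include <cstdint>
#include <cuda/atomic>

using Atomic = cuda::atomic<uint64_t, cuda::thread_scope_system>;

static_assert(sizeof(Atomic) == sizeof(uint64_t),
              "event word must be a bare 64-bit cell");
static_assert(Atomic::is_always_lock_free,
              "host/device signaling requires lock-free atomics");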

mlx/backend/cuda/event.h

@@ -2,6 +2,7 @@
 #pragma once
 
+#include "mlx/allocator.h"
 #include "mlx/stream.h"
 
 #include <cuda_runtime.h>
@@ -55,12 +56,8 @@ class SharedEvent {
   bool is_signaled(uint64_t value) const;
   uint64_t value() const;
 
-  const std::shared_ptr<Atomic>& atomic() const {
-    return ac_;
-  }
-
  private:
-  std::shared_ptr<Atomic> ac_;
+  std::shared_ptr<mlx::core::allocator::Buffer> buf_;
 };
 
 } // namespace mlx::core::cu

mlx/backend/cuda/worker.cpp

@@ -1,7 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/backend/cuda/worker.h"
-#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 
 namespace mlx::core::cu {
@@ -12,10 +11,10 @@ Worker::Worker()
 Worker::~Worker() {
   {
-    std::lock_guard lock(worker_mutex_);
+    std::lock_guard lock(mtx_);
     stop_ = true;
   }
-  worker_event_.signal(batch_ + 1);
+  cond_.notify_one();
   worker_.join();
 }
@@ -23,53 +22,41 @@ void Worker::add_task(std::function<void()> task) {
   pending_tasks_.push_back(std::move(task));
 }
 
-void Worker::consume_in_this_thread() {
-  for (auto& task : pending_tasks_) {
-    task();
-  }
-  pending_tasks_.clear();
-}
-
-void Worker::end_batch() {
-  batch_++;
+void Worker::signal(void* data) {
+  auto w = static_cast<Worker*>(data);
   {
-    std::lock_guard lock(worker_mutex_);
-    worker_tasks_[batch_] = std::move(pending_tasks_);
+    std::lock_guard lock(w->mtx_);
+    w->signaled_batch_++;
   }
-  uncommited_batches_++;
-}
-
-void Worker::commit() {
-  if (uncommited_batches_ == 0) {
-    return;
-  }
-  uncommited_batches_ = 0;
-  worker_event_.signal(batch_);
+  w->cond_.notify_one();
 }
 
 void Worker::commit(cudaStream_t stream) {
-  if (uncommited_batches_ == 0) {
+  if (pending_tasks_.empty()) {
     return;
   }
-  uncommited_batches_ = 0;
-  // Signal the |worker_event_| in |signal_stream_| after the kernels in
-  // |stream_| finish running.
+  {
+    std::lock_guard lock(mtx_);
+    // Move pending tasks into ready tasks
+    worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
+  }
   signal_event_.record(stream);
   signal_event_.wait(signal_stream_);
-  worker_event_.signal(signal_stream_, batch_);
+  cudaLaunchHostFunc(signal_stream_, signal, this);
 }
 
 void Worker::thread_fn() {
-  // The worker thread is safe to free buffers.
-  allocator().register_this_thread();
+  uint64_t current_batch = 0;
   while (!stop_) {
-    uint64_t batch = worker_event_.value();
     Tasks tasks;
     {
-      std::lock_guard lock(worker_mutex_);
-      // Move tasks in signaled batches.
-      auto end = worker_tasks_.upper_bound(batch);
+      std::unique_lock<std::mutex> lk(mtx_);
+      cond_.wait(lk, [this, &current_batch] {
+        return this->signaled_batch_ > current_batch || this->stop_;
+      });
+      current_batch = signaled_batch_;
+      auto end = worker_tasks_.upper_bound(current_batch);
       for (auto it = worker_tasks_.begin(); it != end; ++it) {
         if (tasks.empty()) {
           tasks = std::move(it->second);
@@ -85,7 +72,6 @@ void Worker::thread_fn() {
       auto task = std::move(tasks[i]);
       task();
     }
-    worker_event_.wait(batch + 1);
   }
 }
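The worker no longer spins on a SharedEvent. Instead, commit() enqueues a host function on signal_stream_ with cudaLaunchHostFunc, which the CUDA runtime invokes once prior work on the stream completes. Host functions must not call CUDA APIs, which is why Worker::signal only bumps signaled_batch_ and notifies the condition variable; the tasks themselves run on the worker thread. A standalone sketch of that wake-up shape (illustrative, compile with nvcc):

#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <mutex>

#include <cuda_runtime.h>

std::mutex mtx;
std::condition_variable cond;
uint64_t signaled = 0;

// Runs on a CUDA runtime thread after prior work on the stream finishes.
// Per the CUDA docs it must not call CUDA APIs, so it only publishes a flag.
void CUDART_CB on_stream_done(void* /*user_data*/) {
  {
    std::lock_guard<std::mutex> lock(mtx);
    ++signaled;
  }
  cond.notify_one();
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // ... launch kernels on stream here ...
  cudaLaunchHostFunc(stream, on_stream_done, nullptr);

  std::unique_lock<std::mutex> lk(mtx);
  cond.wait(lk, [] { return signaled > 0; });
  std::printf("stream drained\n");

  cudaStreamDestroy(stream);
  return 0;
}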

mlx/backend/cuda/worker.h

@@ -5,6 +5,7 @@
 #include "mlx/backend/cuda/event.h"
 #include "mlx/backend/cuda/utils.h"
 
+#include <condition_variable>
 #include <functional>
 #include <map>
 #include <mutex>
@@ -24,38 +25,24 @@ class Worker {
   // Add a pending |task| that will run when consumed or commited.
   void add_task(std::function<void()> task);
 
-  // Run pending tasks immediately in current thread.
-  void consume_in_this_thread();
-
-  // Put pending tasks in a batch.
-  void end_batch();
-
-  // Inform worker thread to run current batches now.
-  void commit();
-
   // Inform worker thread to run current batches after kernels in |stream|
   // finish running.
   void commit(cudaStream_t stream);
 
-  // Return how many batches have been added but not committed yet.
-  size_t uncommited_batches() const {
-    return uncommited_batches_;
-  }
-
  private:
+  static void signal(void*);
+
   void thread_fn();
 
-  uint64_t batch_{0};
-  size_t uncommited_batches_{0};
+  std::mutex mtx_;
+  std::condition_variable cond_;
+
+  uint64_t committed_batch_{0};
+  uint64_t signaled_batch_{0};
 
   // Cuda stream and event for signaling kernel completion.
   CudaStream signal_stream_;
   CudaEvent signal_event_;
 
-  // Worker thread.
-  SharedEvent worker_event_;
-  std::thread worker_;
-  std::mutex worker_mutex_;
   bool stop_{false};
 
   // Tasks are put in |pending_tasks_| first, and then moved to
@@ -63,6 +50,7 @@ class Worker {
   using Tasks = std::vector<std::function<void()>>;
   Tasks pending_tasks_;
   std::map<uint64_t, Tasks> worker_tasks_;
+  std::thread worker_;
 };
 
 } // namespace mlx::core::cu
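One subtle point in the reshuffled members: worker_ is now declared last. Members are constructed in declaration order, so if the thread is started in the constructor (as the Worker::Worker() initializer at the top of worker.cpp suggests), mtx_, cond_, and the task containers are fully built before the thread body can touch them. A minimal illustration of the rule:

// A std::thread member whose body touches other members must be
// declared after them: construction follows declaration order, and
// destruction (where the thread is joined) runs in reverse.
#include <mutex>
#include <thread>
#include <vector>

class Runner {
  std::mutex mtx_;        // constructed before the thread starts
  std::vector<int> work_; // likewise
  std::thread t_;         // last: its body may safely use the above

 public:
  Runner()
      : t_([this] {
          std::lock_guard<std::mutex> lock(mtx_);
          work_.push_back(1);
        }) {}
  ~Runner() {
    t_.join(); // t_ is also destroyed first, before mtx_ and work_
  }
};

int main() {
  Runner r;
  return 0;
}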

mlx/backend/metal/allocator.cpp

@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
   auto pool = metal::new_scoped_memory_pool();
 
-  // If we have a lot of memory pressure or are over the maximum cache size,
-  // try to reclaim memory from the cache
+  // If we have a lot of memory pressure try to reclaim memory from the cache
   if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
     num_resources_ -=
         buffer_cache_.release_cached_buffers(mem_required - gc_limit_);