simplify allocator and fixe race with small pool

This commit is contained in:
Awni Hannun
2025-07-19 10:23:59 -07:00
parent 74eccbf3fa
commit b62368f292
8 changed files with 31 additions and 97 deletions

View File

@@ -94,18 +94,21 @@ Buffer CudaAllocator::malloc(size_t size) {
if (!buf) { if (!buf) {
// If we have a lot of memory pressure or are over the maximum cache size, // If we have a lot of memory pressure or are over the maximum cache size,
// try to reclaim memory from the cache. // try to reclaim memory from the cache.
size_t mem_required = get_active_memory() + get_cache_memory() + size; int64_t mem_to_free =
if (mem_required >= memory_limit_) { get_active_memory() + get_cache_memory() + size - memory_limit_;
buffer_cache_.release_cached_buffers(mem_required - memory_limit_); mem_to_free = std::max(
static_cast<int64_t>(get_cache_memory() - max_pool_size_), mem_to_free);
if (mem_to_free > 0) {
buffer_cache_.release_cached_buffers(mem_to_free);
} }
lock.unlock();
buf = new CudaBuffer{nullptr, size}; buf = new CudaBuffer{nullptr, size};
// Try the scalar pool first // Try the scalar pool first
if (size <= small_block_size) { if (size <= small_block_size) {
buf->data = scalar_pool_.malloc(); buf->data = scalar_pool_.malloc();
} }
lock.unlock();
if (!buf->data) { if (!buf->data) {
cudaError_t err = cudaMallocManaged(&buf->data, size); cudaError_t err = cudaMallocManaged(&buf->data, size);
if (err != cudaSuccess && err != cudaErrorMemoryAllocation) { if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
@@ -113,7 +116,6 @@ Buffer CudaAllocator::malloc(size_t size) {
"cudaMallocManaged failed: {}.", cudaGetErrorString(err))); "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
} }
} }
lock.lock(); lock.lock();
} }
active_memory_ += size; active_memory_ += size;
@@ -123,7 +125,6 @@ Buffer CudaAllocator::malloc(size_t size) {
if (get_cache_memory() > max_pool_size_) { if (get_cache_memory() > max_pool_size_) {
buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_); buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
} }
return Buffer{buf}; return Buffer{buf};
} }
@@ -135,13 +136,7 @@ void CudaAllocator::free(Buffer buffer) {
std::unique_lock lock(mutex_); std::unique_lock lock(mutex_);
active_memory_ -= buf->size; active_memory_ -= buf->size;
if (get_cache_memory() < max_pool_size_) { buffer_cache_.recycle_to_cache(buf);
buffer_cache_.recycle_to_cache(buf);
} else {
lock.unlock();
cuda_free(buf->data);
delete buf;
}
} }
size_t CudaAllocator::size(Buffer buffer) const { size_t CudaAllocator::size(Buffer buffer) const {
@@ -152,26 +147,8 @@ size_t CudaAllocator::size(Buffer buffer) const {
return buf->size; return buf->size;
} }
void CudaAllocator::register_this_thread() { // This must be called with mutex_ aquired
std::lock_guard lock(worker_mutex_);
allowed_threads_.insert(std::this_thread::get_id());
}
void CudaAllocator::cuda_free(void* buf) { void CudaAllocator::cuda_free(void* buf) {
// If cuda_free() is called from a unregistered thread, reschedule the call to
// worker.
{
std::lock_guard lock(worker_mutex_);
if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
if (!worker_) {
worker_.reset(new Worker);
}
worker_->add_task([this, buf]() { this->cuda_free(buf); });
worker_->end_batch();
worker_->commit();
return;
}
}
if (scalar_pool_.in_pool(buf)) { if (scalar_pool_.in_pool(buf)) {
scalar_pool_.free(buf); scalar_pool_.free(buf);
} else { } else {

View File

@@ -7,13 +7,10 @@
#include <mutex> #include <mutex>
#include <set> #include <set>
#include <thread>
#include <utility> #include <utility>
namespace mlx::core::cu { namespace mlx::core::cu {
class Worker;
using allocator::Buffer; using allocator::Buffer;
// Stores cuda-managed unified memory. // Stores cuda-managed unified memory.
@@ -50,15 +47,6 @@ class CudaAllocator : public allocator::Allocator {
void free(Buffer buffer) override; void free(Buffer buffer) override;
size_t size(Buffer buffer) const override; size_t size(Buffer buffer) const override;
// Register current thread as safe to free buffers.
// In cuda freeing a buffer implicitly synchronizes stream, and for threads
// that may be waited by gpu stream (for example cpu stream threads), freeing
// buffers there would result in dead lock.
void register_this_thread();
// Call cudaFree in the safe thread.
void cuda_free(void* buf);
size_t get_active_memory() const; size_t get_active_memory() const;
size_t get_peak_memory() const; size_t get_peak_memory() const;
void reset_peak_memory(); void reset_peak_memory();
@@ -69,13 +57,11 @@ class CudaAllocator : public allocator::Allocator {
void clear_cache(); void clear_cache();
private: private:
void cuda_free(void* buf);
CudaAllocator(); CudaAllocator();
friend CudaAllocator& allocator(); friend CudaAllocator& allocator();
std::mutex worker_mutex_;
std::unique_ptr<Worker> worker_;
std::set<std::thread::id> allowed_threads_;
std::mutex mutex_; std::mutex mutex_;
size_t memory_limit_; size_t memory_limit_;
size_t max_pool_size_; size_t max_pool_size_;

View File

@@ -315,7 +315,6 @@ void CommandEncoder::synchronize() {
auto p = std::make_shared<std::promise<void>>(); auto p = std::make_shared<std::promise<void>>();
std::future<void> f = p->get_future(); std::future<void> f = p->get_future();
add_completed_handler([p = std::move(p)]() { p->set_value(); }); add_completed_handler([p = std::move(p)]() { p->set_value(); });
worker_.end_batch();
commit(); commit();
f.wait(); f.wait();
} }

View File

@@ -19,8 +19,6 @@ void new_stream(Stream s) {
cudaFree(nullptr); cudaFree(nullptr);
// Ensure the static stream objects get created. // Ensure the static stream objects get created.
cu::get_command_encoder(s); cu::get_command_encoder(s);
// The main thread is safe to free buffers.
cu::allocator().register_this_thread();
} }
void eval(array& arr) { void eval(array& arr) {

View File

@@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
event_signal(ac, value); event_signal(ac, value);
} }
SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
}
SharedEvent::SharedEvent() { SharedEvent::SharedEvent() {
// Allocate cuda::atomic on managed memory. buf_ = std::shared_ptr<Buffer>(
Atomic* ac; new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic))); allocator().free(*ptr);
new (ac) Atomic(0); delete ptr;
ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) { });
ptr->~Atomic(); *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
allocator().cuda_free(ptr);
});
} }
void SharedEvent::wait(uint64_t value) { void SharedEvent::wait(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::wait"); nvtx3::scoped_range r("cu::SharedEvent::wait");
event_wait(ac_.get(), value); event_wait(to_atomic(buf_), value);
} }
void SharedEvent::wait(cudaStream_t stream, uint64_t value) { void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
} }
void SharedEvent::wait(Stream s, uint64_t value) { void SharedEvent::wait(Stream s, uint64_t value) {
@@ -138,17 +140,17 @@ void SharedEvent::wait(Stream s, uint64_t value) {
auto& encoder = get_command_encoder(s); auto& encoder = get_command_encoder(s);
encoder.commit(); encoder.commit();
wait(encoder.stream(), value); wait(encoder.stream(), value);
encoder.add_completed_handler([ac = ac_]() {}); encoder.add_completed_handler([buf = buf_]() {});
} }
} }
void SharedEvent::signal(uint64_t value) { void SharedEvent::signal(uint64_t value) {
nvtx3::scoped_range r("cu::SharedEvent::signal"); nvtx3::scoped_range r("cu::SharedEvent::signal");
event_signal(ac_.get(), value); event_signal(to_atomic(buf_), value);
} }
void SharedEvent::signal(cudaStream_t stream, uint64_t value) { void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value); event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
} }
void SharedEvent::signal(Stream s, uint64_t value) { void SharedEvent::signal(Stream s, uint64_t value) {
@@ -162,18 +164,18 @@ void SharedEvent::signal(Stream s, uint64_t value) {
auto& encoder = get_command_encoder(s); auto& encoder = get_command_encoder(s);
encoder.commit(); encoder.commit();
signal(encoder.stream(), value); signal(encoder.stream(), value);
encoder.add_completed_handler([ac = ac_]() {}); encoder.add_completed_handler([buf = buf_]() {});
} }
} }
bool SharedEvent::is_signaled(uint64_t value) const { bool SharedEvent::is_signaled(uint64_t value) const {
nvtx3::scoped_range r("cu::SharedEvent::is_signaled"); nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
return ac_->load() >= value; return to_atomic(buf_)->load() >= value;
} }
uint64_t SharedEvent::value() const { uint64_t SharedEvent::value() const {
nvtx3::scoped_range r("cu::SharedEvent::value"); nvtx3::scoped_range r("cu::SharedEvent::value");
return ac_->load(); return to_atomic(buf_)->load();
} }
} // namespace cu } // namespace cu

View File

@@ -2,6 +2,7 @@
#pragma once #pragma once
#include "mlx/allocator.h"
#include "mlx/stream.h" #include "mlx/stream.h"
#include <cuda_runtime.h> #include <cuda_runtime.h>
@@ -55,12 +56,8 @@ class SharedEvent {
bool is_signaled(uint64_t value) const; bool is_signaled(uint64_t value) const;
uint64_t value() const; uint64_t value() const;
const std::shared_ptr<Atomic>& atomic() const {
return ac_;
}
private: private:
std::shared_ptr<Atomic> ac_; std::shared_ptr<mlx::core::allocator::Buffer> buf_;
}; };
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -1,7 +1,6 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/worker.h" #include "mlx/backend/cuda/worker.h"
#include "mlx/backend/cuda/allocator.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
namespace mlx::core::cu { namespace mlx::core::cu {
@@ -23,13 +22,6 @@ void Worker::add_task(std::function<void()> task) {
pending_tasks_.push_back(std::move(task)); pending_tasks_.push_back(std::move(task));
} }
void Worker::consume_in_this_thread() {
for (auto& task : pending_tasks_) {
task();
}
pending_tasks_.clear();
}
void Worker::end_batch() { void Worker::end_batch() {
batch_++; batch_++;
{ {
@@ -39,14 +31,6 @@ void Worker::end_batch() {
uncommited_batches_++; uncommited_batches_++;
} }
void Worker::commit() {
if (uncommited_batches_ == 0) {
return;
}
uncommited_batches_ = 0;
worker_event_.signal(batch_);
}
void Worker::commit(cudaStream_t stream) { void Worker::commit(cudaStream_t stream) {
if (uncommited_batches_ == 0) { if (uncommited_batches_ == 0) {
return; return;
@@ -60,9 +44,6 @@ void Worker::commit(cudaStream_t stream) {
} }
void Worker::thread_fn() { void Worker::thread_fn() {
// The worker thread is safe to free buffers.
allocator().register_this_thread();
while (!stop_) { while (!stop_) {
uint64_t batch = worker_event_.value(); uint64_t batch = worker_event_.value();
Tasks tasks; Tasks tasks;

View File

@@ -24,15 +24,9 @@ class Worker {
// Add a pending |task| that will run when consumed or commited. // Add a pending |task| that will run when consumed or commited.
void add_task(std::function<void()> task); void add_task(std::function<void()> task);
// Run pending tasks immediately in current thread.
void consume_in_this_thread();
// Put pending tasks in a batch. // Put pending tasks in a batch.
void end_batch(); void end_batch();
// Inform worker thread to run current batches now.
void commit();
// Inform worker thread to run current batches after kernels in |stream| // Inform worker thread to run current batches after kernels in |stream|
// finish running. // finish running.
void commit(cudaStream_t stream); void commit(cudaStream_t stream);