diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 66e8c5c660..93bf48542d 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -2,7 +2,6 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
-#include "mlx/backend/cuda/worker.h"
 #include "mlx/utils.h"

 #include
@@ -25,52 +24,58 @@ constexpr int small_block_size = 8;
 constexpr int small_pool_size = 4 * page_size;

 SmallSizePool::SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaMallocManaged(&buffer_, small_pool_size));
-  end_ = reinterpret_cast<void*>(
-      reinterpret_cast<char*>(buffer_) + small_pool_size);
-  next_free_ = reinterpret_cast<Block*>(buffer_);
   auto num_blocks = small_pool_size / small_block_size;
+  buffer_ = new Block[num_blocks];
+
+  next_free_ = buffer_;
+
+  CHECK_CUDA_ERROR(cudaMallocManaged(&data_, small_pool_size));
+  CHECK_CUDA_ERROR(
+      cudaMemAdvise(data_, small_pool_size, cudaMemAdviseSetReadMostly, 0));
+
   auto curr = next_free_;
-  for (size_t i = 0; i < num_blocks - 1; ++i) {
-    curr->next = reinterpret_cast<Block*>(
-        reinterpret_cast<char*>(buffer_) + (i + 1) * small_block_size);
+  for (size_t i = 1; i < num_blocks; ++i) {
+    curr->next = buffer_ + i;
     curr = curr->next;
   }
   curr->next = nullptr;
 }

 SmallSizePool::~SmallSizePool() {
-  CHECK_CUDA_ERROR(cudaFree(buffer_));
+  CHECK_CUDA_ERROR(cudaFree(data_));
+  delete[] buffer_;
 }

-void* SmallSizePool::malloc() {
+CudaBuffer* SmallSizePool::malloc() {
   if (next_free_ == nullptr) {
     return nullptr;
   }
   Block* b = next_free_;
+  uint64_t i = next_free_ - buffer_;
   next_free_ = next_free_->next;
-  return static_cast<void*>(b);
+  b->buf.data = static_cast<char*>(data_) + i * small_block_size;
+  b->buf.size = small_block_size;
+  return &b->buf;
 }

-void SmallSizePool::free(void* p) {
-  auto b = static_cast<Block*>(p);
+void SmallSizePool::free(CudaBuffer* buf) {
+  auto b = reinterpret_cast<Block*>(buf);
   b->next = next_free_;
   next_free_ = b;
 }

-bool SmallSizePool::in_pool(void* p) {
-  return (p >= buffer_) && (p < end_);
+bool SmallSizePool::in_pool(CudaBuffer* buf) {
+  constexpr int num_blocks = (small_pool_size / small_block_size);
+  auto b = reinterpret_cast<Block*>(buf);
+  int64_t block_num = b - buffer_;
+  return block_num >= 0 && block_num < num_blocks;
 }

 CudaAllocator::CudaAllocator()
     : buffer_cache_(
           page_size,
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) {
-            cuda_free(buf->data);
-            delete buf;
-          }) {
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -92,28 +97,26 @@ Buffer CudaAllocator::malloc(size_t size) {
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
-    // If we have a lot of memory pressure or are over the maximum cache size,
-    // try to reclaim memory from the cache.
-    size_t mem_required = get_active_memory() + get_cache_memory() + size;
-    if (mem_required >= memory_limit_) {
-      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    // If we have a lot of memory pressure try to reclaim memory from the cache.
+    int64_t mem_to_free =
+        get_active_memory() + get_cache_memory() + size - memory_limit_;
+    if (mem_to_free > 0) {
+      buffer_cache_.release_cached_buffers(mem_to_free);
     }
-    lock.unlock();
-    buf = new CudaBuffer{nullptr, size};
-
     // Try the scalar pool first
     if (size <= small_block_size) {
-      buf->data = scalar_pool_.malloc();
+      buf = scalar_pool_.malloc();
     }
-    if (!buf->data) {
+    lock.unlock();
+    if (!buf) {
+      buf = new CudaBuffer{nullptr, size};
       cudaError_t err = cudaMallocManaged(&buf->data, size);
       if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
         throw std::runtime_error(fmt::format(
             "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
       }
     }
-    lock.lock();
   }

   active_memory_ += size;
@@ -123,7 +126,6 @@ Buffer CudaAllocator::malloc(size_t size) {
   if (get_cache_memory() > max_pool_size_) {
     buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
   }
-
   return Buffer{buf};
 }

@@ -138,9 +140,7 @@ void CudaAllocator::free(Buffer buffer) {
   if (get_cache_memory() < max_pool_size_) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
-    lock.unlock();
-    cuda_free(buf->data);
-    delete buf;
+    cuda_free(buf);
   }
 }

@@ -152,30 +152,13 @@ size_t CudaAllocator::size(Buffer buffer) const {
   return buf->size;
 }

-void CudaAllocator::register_this_thread() {
-  std::lock_guard lock(worker_mutex_);
-  allowed_threads_.insert(std::this_thread::get_id());
-}
-
-void CudaAllocator::cuda_free(void* buf) {
-  // If cuda_free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
+// This must be called with mutex_ acquired
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
   if (scalar_pool_.in_pool(buf)) {
     scalar_pool_.free(buf);
   } else {
-    cudaFree(buf);
+    cudaFree(buf->data);
+    delete buf;
   }
 }
diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h
index f7474dda6d..81b3dde593 100644
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -7,13 +7,10 @@
 #include
 #include
-#include
 #include

 namespace mlx::core::cu {

-class Worker;
-
 using allocator::Buffer;

 // Stores cuda-managed unified memory.
@@ -24,13 +21,14 @@ struct CudaBuffer {

 class SmallSizePool {
  private:
-  struct Block {
+  union Block {
     Block* next;
+    CudaBuffer buf;
   };

-  void* buffer_{nullptr};
+  Block* buffer_{nullptr};
+  void* data_{nullptr};
   Block* next_free_{nullptr};
-  void* end_{nullptr};

  public:
   SmallSizePool();
@@ -39,9 +37,9 @@ class SmallSizePool {
   SmallSizePool(const SmallSizePool&) = delete;
   SmallSizePool& operator=(const SmallSizePool&) = delete;

-  void* malloc();
-  void free(void* p);
-  bool in_pool(void* p);
+  CudaBuffer* malloc();
+  void free(CudaBuffer* buf);
+  bool in_pool(CudaBuffer* buf);
 };

 class CudaAllocator : public allocator::Allocator {
@@ -50,15 +48,6 @@
   void free(Buffer buffer) override;
   size_t size(Buffer buffer) const override;

-  // Register current thread as safe to free buffers.
-  // In cuda freeing a buffer implicitly synchronizes stream, and for threads
-  // that may be waited by gpu stream (for example cpu stream threads), freeing
-  // buffers there would result in dead lock.
-  void register_this_thread();
-
-  // Call cudaFree in the safe thread.
-  void cuda_free(void* buf);
-
   size_t get_active_memory() const;
   size_t get_peak_memory() const;
   void reset_peak_memory();
@@ -69,13 +58,11 @@
   void clear_cache();

  private:
+  void cuda_free(CudaBuffer* buf);
+
   CudaAllocator();
   friend CudaAllocator& allocator();

-  std::mutex worker_mutex_;
-  std::unique_ptr<Worker> worker_;
-  std::set<std::thread::id> allowed_threads_;
-
   std::mutex mutex_;
   size_t memory_limit_;
   size_t max_pool_size_;
diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 3362315283..366cdf8269 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -306,7 +306,6 @@ void CommandEncoder::commit() {
   }

   // Put completion handlers in a batch.
-  worker_.end_batch();
   worker_.commit(stream_);
 }

@@ -315,7 +314,6 @@ void CommandEncoder::synchronize() {
   auto p = std::make_shared<std::promise<void>>();
   std::future<void> f = p->get_future();
   add_completed_handler([p = std::move(p)]() { p->set_value(); });
-  worker_.end_batch();
   commit();
   f.wait();
 }
diff --git a/mlx/backend/cuda/eval.cpp b/mlx/backend/cuda/eval.cpp
index 40beb12d2e..0e1477e950 100644
--- a/mlx/backend/cuda/eval.cpp
+++ b/mlx/backend/cuda/eval.cpp
@@ -19,8 +19,6 @@ void new_stream(Stream s) {
   cudaFree(nullptr);
   // Ensure the static stream objects get created.
   cu::get_command_encoder(s);
-  // The main thread is safe to free buffers.
-  cu::allocator().register_this_thread();
 }

 void eval(array& arr) {
diff --git a/mlx/backend/cuda/event.cu b/mlx/backend/cuda/event.cu
index f51d2f2e35..c9c6b36592 100644
--- a/mlx/backend/cuda/event.cu
+++ b/mlx/backend/cuda/event.cu
@@ -110,24 +110,26 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
   event_signal(ac, value);
 }

+SharedEvent::Atomic* to_atomic(std::shared_ptr<Buffer> buf) {
+  return static_cast<SharedEvent::Atomic*>(buf->raw_ptr());
+}
+
 SharedEvent::SharedEvent() {
-  // Allocate cuda::atomic on managed memory.
-  Atomic* ac;
-  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
-  new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
-    ptr->~Atomic();
-    allocator().cuda_free(ptr);
-  });
+  buf_ = std::shared_ptr<Buffer>(
+      new Buffer{allocator().malloc(sizeof(Atomic))}, [](Buffer* ptr) {
+        allocator().free(*ptr);
+        delete ptr;
+      });
+  *static_cast<uint64_t*>(buf_->raw_ptr()) = 0;
 }

 void SharedEvent::wait(uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::wait");
-  event_wait(ac_.get(), value);
+  event_wait(to_atomic(buf_), value);
 }

 void SharedEvent::wait(cudaStream_t stream, uint64_t value) {
-  event_wait_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
+  event_wait_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
 }

 void SharedEvent::wait(Stream s, uint64_t value) {
@@ -138,17 +140,17 @@
     auto& encoder = get_command_encoder(s);
     encoder.commit();
     wait(encoder.stream(), value);
-    encoder.add_completed_handler([ac = ac_]() {});
+    encoder.add_completed_handler([buf = buf_]() {});
   }
 }

 void SharedEvent::signal(uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::signal");
-  event_signal(ac_.get(), value);
+  event_signal(to_atomic(buf_), value);
 }

 void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
-  event_signal_kernel<<<1, 1, 0, stream>>>(ac_.get(), value);
+  event_signal_kernel<<<1, 1, 0, stream>>>(to_atomic(buf_), value);
 }

 void SharedEvent::signal(Stream s, uint64_t value) {
@@ -162,18 +164,18 @@
     auto& encoder = get_command_encoder(s);
     encoder.commit();
     signal(encoder.stream(), value);
-    encoder.add_completed_handler([ac = ac_]() {});
+    encoder.add_completed_handler([buf = buf_]() {});
   }
 }

 bool SharedEvent::is_signaled(uint64_t value) const {
   nvtx3::scoped_range r("cu::SharedEvent::is_signaled");
-  return ac_->load() >= value;
+  return to_atomic(buf_)->load() >= value;
 }

 uint64_t SharedEvent::value() const {
   nvtx3::scoped_range r("cu::SharedEvent::value");
-  return ac_->load();
+  return to_atomic(buf_)->load();
 }

 } // namespace cu
diff --git a/mlx/backend/cuda/event.h b/mlx/backend/cuda/event.h
index 4b56e2e3b8..3ef9786f32 100644
--- a/mlx/backend/cuda/event.h
+++ b/mlx/backend/cuda/event.h
@@ -2,6 +2,7 @@

 #pragma once

+#include "mlx/allocator.h"
 #include "mlx/stream.h"

 #include
@@ -55,12 +56,8 @@ class SharedEvent {
   bool is_signaled(uint64_t value) const;
   uint64_t value() const;

-  const std::shared_ptr<Atomic>& atomic() const {
-    return ac_;
-  }
-
  private:
-  std::shared_ptr<Atomic> ac_;
+  std::shared_ptr<allocator::Buffer> buf_;
 };

 } // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/worker.cpp b/mlx/backend/cuda/worker.cpp
index 3b35c830b4..ce211367ce 100644
--- a/mlx/backend/cuda/worker.cpp
+++ b/mlx/backend/cuda/worker.cpp
@@ -1,7 +1,6 @@
 // Copyright © 2025 Apple Inc.
 #include "mlx/backend/cuda/worker.h"
-#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"

 namespace mlx::core::cu {
@@ -12,10 +11,10 @@ Worker::Worker()

 Worker::~Worker() {
   {
-    std::lock_guard lock(worker_mutex_);
+    std::lock_guard lock(mtx_);
     stop_ = true;
   }
-  worker_event_.signal(batch_ + 1);
+  cond_.notify_one();
   worker_.join();
 }

@@ -23,53 +22,41 @@ void Worker::add_task(std::function<void()> task) {
   pending_tasks_.push_back(std::move(task));
 }

-void Worker::consume_in_this_thread() {
-  for (auto& task : pending_tasks_) {
-    task();
-  }
-  pending_tasks_.clear();
-}
-
-void Worker::end_batch() {
-  batch_++;
+void Worker::signal(void* data) {
+  auto w = static_cast<Worker*>(data);
   {
-    std::lock_guard lock(worker_mutex_);
-    worker_tasks_[batch_] = std::move(pending_tasks_);
+    std::lock_guard lock(w->mtx_);
+    w->signaled_batch_++;
   }
-  uncommited_batches_++;
-}
-
-void Worker::commit() {
-  if (uncommited_batches_ == 0) {
-    return;
-  }
-  uncommited_batches_ = 0;
-  worker_event_.signal(batch_);
+  w->cond_.notify_one();
 }

 void Worker::commit(cudaStream_t stream) {
-  if (uncommited_batches_ == 0) {
+  // Move pending tasks into tasks
+  if (pending_tasks_.empty()) {
     return;
   }
-  uncommited_batches_ = 0;
-  // Signal the |worker_event_| in |signal_stream_| after the kernels in
-  // |stream_| finish running.
+  {
+    std::lock_guard lock(mtx_);
+    // Move pending tasks into ready tasks
+    worker_tasks_[++committed_batch_] = std::move(pending_tasks_);
+  }
   signal_event_.record(stream);
   signal_event_.wait(signal_stream_);
-  worker_event_.signal(signal_stream_, batch_);
+  cudaLaunchHostFunc(signal_stream_, signal, this);
 }

 void Worker::thread_fn() {
-  // The worker thread is safe to free buffers.
-  allocator().register_this_thread();
-
   while (!stop_) {
-    uint64_t batch = worker_event_.value();
+    uint64_t current_batch = 0;
     Tasks tasks;
     {
-      std::lock_guard lock(worker_mutex_);
-      // Move tasks in signaled batches.
-      auto end = worker_tasks_.upper_bound(batch);
+      std::unique_lock lk(mtx_);
+      cond_.wait(lk, [this, &current_batch] {
+        return this->signaled_batch_ > current_batch || this->stop_;
+      });
+      current_batch = signaled_batch_;
+      auto end = worker_tasks_.upper_bound(current_batch);
       for (auto it = worker_tasks_.begin(); it != end; ++it) {
         if (tasks.empty()) {
           tasks = std::move(it->second);
@@ -85,7 +72,6 @@ void Worker::thread_fn() {
       auto task = std::move(tasks[i]);
       task();
     }
-    worker_event_.wait(batch + 1);
   }
 }
diff --git a/mlx/backend/cuda/worker.h b/mlx/backend/cuda/worker.h
index d28e22e952..df6647e2b7 100644
--- a/mlx/backend/cuda/worker.h
+++ b/mlx/backend/cuda/worker.h
@@ -5,6 +5,7 @@
 #include "mlx/backend/cuda/event.h"
 #include "mlx/backend/cuda/utils.h"

+#include <condition_variable>
 #include
 #include
 #include
@@ -24,38 +25,24 @@ class Worker {
   // Add a pending |task| that will run when consumed or commited.
   void add_task(std::function<void()> task);

-  // Run pending tasks immediately in current thread.
-  void consume_in_this_thread();
-
-  // Put pending tasks in a batch.
-  void end_batch();
-
-  // Inform worker thread to run current batches now.
-  void commit();
-
   // Inform worker thread to run current batches after kernels in |stream|
   // finish running.
   void commit(cudaStream_t stream);

-  // Return how many batches have been added but not committed yet.
-  size_t uncommited_batches() const {
-    return uncommited_batches_;
-  }
-
  private:
-  void thread_fn();
+  static void signal(void*);

-  uint64_t batch_{0};
-  size_t uncommited_batches_{0};
+  void thread_fn();
+
+  std::mutex mtx_;
+  std::condition_variable cond_;
+
+  uint64_t committed_batch_{0};
+  uint64_t signaled_batch_{0};

   // Cuda stream and event for signaling kernel completion.
   CudaStream signal_stream_;
   CudaEvent signal_event_;

-  // Worker thread.
-  SharedEvent worker_event_;
-  std::thread worker_;
-  std::mutex worker_mutex_;
   bool stop_{false};

   // Tasks are put in |pending_tasks_| first, and then moved to
@@ -63,6 +50,7 @@ class Worker {
   using Tasks = std::vector<std::function<void()>>;
   Tasks pending_tasks_;
   std::map<uint64_t, Tasks> worker_tasks_;
+  std::thread worker_;
 };

 } // namespace mlx::core::cu
diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index dd61897326..8eb70bcbca 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -128,8 +128,7 @@ Buffer MetalAllocator::malloc(size_t size) {
   auto pool = metal::new_scoped_memory_pool();

-  // If we have a lot of memory pressure or are over the maximum cache size,
-  // try to reclaim memory from the cache
+  // If we have a lot of memory pressure try to reclaim memory from the cache
   if (mem_required >= gc_limit_ || num_resources_ >= resource_limit_) {
     num_resources_ -=
         buffer_cache_.release_cached_buffers(mem_required - gc_limit_);
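For readers of the worker changes above: the patch replaces the SharedEvent-based batching (end_batch()/commit()) with a mutex, a condition variable, and a cudaLaunchHostFunc callback that advances a signaled-batch counter. Below is a self-contained C++ sketch of that committed/signaled batch hand-off. It is not MLX code: ToyWorker, commit_and_signal(), and the synchronous signalling are invented for illustration, whereas in the patch the signal step runs from a CUDA host function only after signal_event_ has been recorded on the work stream and waited on by signal_stream_.

// Minimal sketch (not MLX code) of the batch hand-off pattern used by the
// reworked Worker: the producer moves pending tasks into a numbered batch,
// a signal step marks batches as ready, and the worker thread drains every
// batch up to the signaled count under a condition variable.
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <map>
#include <mutex>
#include <thread>
#include <vector>

struct ToyWorker {
  using Tasks = std::vector<std::function<void()>>;

  // Producer side: queue work for the next batch (single producer assumed,
  // matching how pending_tasks_ is used in the patch).
  void add_task(std::function<void()> task) {
    pending_.push_back(std::move(task));
  }

  // In the patch this is split into commit(stream), which files the batch,
  // and a cudaLaunchHostFunc callback, which bumps the signaled counter once
  // the recorded CUDA event has been reached. Here both steps run inline.
  void commit_and_signal() {
    if (pending_.empty()) {
      return;
    }
    {
      std::lock_guard<std::mutex> lock(mtx_);
      tasks_[++committed_] = std::move(pending_);
      signaled_ = committed_;  // done by the host-func callback in the patch
    }
    cond_.notify_one();
  }

  // Must be called before destruction so the worker thread is joined.
  void stop() {
    {
      std::lock_guard<std::mutex> lock(mtx_);
      stop_ = true;
    }
    cond_.notify_one();
    thread_.join();
  }

 private:
  void run() {
    uint64_t current = 0;
    while (true) {
      Tasks batch;
      {
        std::unique_lock<std::mutex> lk(mtx_);
        cond_.wait(lk, [&] { return signaled_ > current || stop_; });
        if (stop_ && signaled_ <= current) {
          break;  // nothing left to drain
        }
        current = signaled_;
        auto end = tasks_.upper_bound(current);
        for (auto it = tasks_.begin(); it != end; ++it) {
          for (auto& t : it->second) {
            batch.push_back(std::move(t));
          }
        }
        tasks_.erase(tasks_.begin(), end);
      }
      // Run the drained tasks outside the lock so producers are not blocked.
      for (auto& t : batch) {
        t();
      }
    }
  }

  std::mutex mtx_;
  std::condition_variable cond_;
  uint64_t committed_{0};
  uint64_t signaled_{0};
  bool stop_{false};
  Tasks pending_;
  std::map<uint64_t, Tasks> tasks_;
  // Declared last so the thread starts only after the state above exists.
  std::thread thread_{&ToyWorker::run, this};
};

int main() {
  ToyWorker w;
  w.add_task([] { std::puts("batch 1 ran on the worker thread"); });
  w.commit_and_signal();
  w.stop();
  return 0;
}

The key design point the patch shares with this sketch is that task execution happens outside the mutex, so enqueueing new work and signaling further batches never wait on a running completion handler.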