diff --git a/mlx/backend/common/buffer_cache.h b/mlx/backend/common/buffer_cache.h
new file mode 100644
index 000000000..92b20f222
--- /dev/null
+++ b/mlx/backend/common/buffer_cache.h
@@ -0,0 +1,157 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <functional>
+#include <map>
+
+namespace mlx::core {
+
+template <typename T>
+class BufferCache {
+ public:
+  BufferCache(
+      size_t page_size,
+      std::function<size_t(T*)> get_size,
+      std::function<void(T*)> free)
+      : page_size_(page_size),
+        get_size_(std::move(get_size)),
+        free_(std::move(free)) {}
+
+  ~BufferCache() {
+    clear();
+  }
+
+  BufferCache(const BufferCache&) = delete;
+  BufferCache& operator=(const BufferCache&) = delete;
+
+  T* reuse_from_cache(size_t size) {
+    // Find the closest buffer in the pool.
+    auto it = buffer_pool_.lower_bound(size);
+    if (it == buffer_pool_.end() ||
+        it->first >= std::min(2 * size, size + 2 * page_size_)) {
+      return nullptr;
+    }
+
+    // Collect from the cache.
+    T* buf = it->second->buf;
+    pool_size_ -= it->first;
+
+    // Remove from record.
+    remove_from_list(it->second);
+    buffer_pool_.erase(it);
+    return buf;
+  }
+
+  void recycle_to_cache(T* buf) {
+    assert(buf);
+    // Add to cache.
+    BufferHolder* bh = new BufferHolder(buf);
+    add_at_head(bh);
+    size_t size = get_size_(buf);
+    pool_size_ += size;
+    buffer_pool_.emplace(size, bh);
+  }
+
+  int release_cached_buffers(size_t min_bytes_to_free) {
+    if (min_bytes_to_free >= 0.9 * pool_size_) {
+      return clear();
+    } else {
+      int n_release = 0;
+      size_t total_bytes_freed = 0;
+
+      while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
+        // Release buffer.
+        size_t size = get_size_(tail_->buf);
+        total_bytes_freed += size;
+        free_(tail_->buf);
+        n_release++;
+
+        // Remove from record.
+        auto its = buffer_pool_.equal_range(size);
+        auto it = std::find_if(its.first, its.second, [this](const auto& el) {
+          return el.second == tail_;
+        });
+        assert(it != buffer_pool_.end());
+        buffer_pool_.erase(it);
+        remove_from_list(tail_);
+      }
+
+      pool_size_ -= total_bytes_freed;
+      return n_release;
+    }
+  }
+
+  int clear() {
+    int n_release = 0;
+    for (auto& [size, holder] : buffer_pool_) {
+      free_(holder->buf);
+      n_release++;
+      delete holder;
+    }
+    buffer_pool_.clear();
+    pool_size_ = 0;
+    head_ = nullptr;
+    tail_ = nullptr;
+    return n_release;
+  }
+
+  size_t cache_size() const {
+    return pool_size_;
+  }
+
+  size_t page_size() const {
+    return page_size_;
+  }
+
+ private:
+  struct BufferHolder {
+   public:
+    explicit BufferHolder(T* buf_) : buf(buf_) {}
+
+    BufferHolder* prev{nullptr};
+    BufferHolder* next{nullptr};
+    T* buf;
+  };
+
+  void add_at_head(BufferHolder* to_add) {
+    if (!head_) {
+      head_ = to_add;
+      tail_ = to_add;
+    } else {
+      head_->prev = to_add;
+      to_add->next = head_;
+      head_ = to_add;
+    }
+  }
+
+  void remove_from_list(BufferHolder* to_remove) {
+    if (to_remove->prev && to_remove->next) { // if middle
+      to_remove->prev->next = to_remove->next;
+      to_remove->next->prev = to_remove->prev;
+    } else if (to_remove->prev && to_remove == tail_) { // if tail
+      tail_ = to_remove->prev;
+      tail_->next = nullptr;
+    } else if (to_remove == head_ && to_remove->next) { // if head
+      head_ = to_remove->next;
+      head_->prev = nullptr;
+    } else if (to_remove == head_ && to_remove == tail_) { // if only element
+      head_ = nullptr;
+      tail_ = nullptr;
+    }
+
+    delete to_remove;
+  }
+
+  std::multimap<size_t, BufferHolder*> buffer_pool_;
+  BufferHolder* head_{nullptr};
+  BufferHolder* tail_{nullptr};
+  size_t pool_size_{0};
+
+  const size_t page_size_;
+  std::function<size_t(T*)> get_size_;
+  std::function<void(T*)> free_;
+};
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 203534e21..86af3a774 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -6,6 +6,7 @@
 #include <cuda_runtime.h>
 #include <fmt/format.h>
+#include <unistd.h>
 
 #include <cassert>
@@ -13,24 +14,47 @@ namespace mlx::core {
 
 namespace cu {
 
-CudaAllocator::CudaAllocator() {
+CudaAllocator::CudaAllocator()
+    : buffer_cache_(
+          getpagesize(),
+          [](CudaBuffer* buf) { return buf->size; },
+          [this](CudaBuffer* buf) { cuda_free(buf); }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
   memory_limit_ = total * 0.8;
+  max_pool_size_ = memory_limit_;
 }
 
 Buffer CudaAllocator::malloc(size_t size) {
-  // TODO: Check memory limit.
-  auto* buf = new CudaBuffer{nullptr, size};
-  cudaError_t err = cudaMallocManaged(&buf->data, size);
-  if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
-    throw std::runtime_error(
-        fmt::format("cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+  // Find an available buffer in the cache.
+  std::unique_lock lock(mutex_);
+  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
+  if (!buf) {
+    // If we have a lot of memory pressure or are over the maximum cache size,
+    // try to reclaim memory from the cache.
+    size_t mem_required = get_active_memory() + get_cache_memory() + size;
+    if (mem_required >= memory_limit_) {
+      buffer_cache_.release_cached_buffers(mem_required - memory_limit_);
+    }
+
+    lock.unlock();
+    buf = new CudaBuffer{nullptr, size};
+    cudaError_t err = cudaMallocManaged(&buf->data, size);
+    if (err != cudaSuccess && err != cudaErrorMemoryAllocation) {
+      throw std::runtime_error(fmt::format(
+          "cudaMallocManaged failed: {}.", cudaGetErrorString(err)));
+    }
+    lock.lock();
   }
-  std::lock_guard lock(mutex_);
   active_memory_ += size;
   peak_memory_ = std::max(active_memory_, peak_memory_);
+
+  // Maintain the cache below the requested limit.
+  if (get_cache_memory() > max_pool_size_) {
+    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
+  }
+
   return Buffer{buf};
 }
 
@@ -40,26 +64,14 @@ void CudaAllocator::free(Buffer buffer) {
     return;
   }
 
-  // If free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([buffer]() { allocator().free(buffer); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
+  std::unique_lock lock(mutex_);
+  active_memory_ -= buf->size;
+  if (get_cache_memory() < max_pool_size_) {
+    buffer_cache_.recycle_to_cache(buf);
+  } else {
+    lock.unlock();
+    cuda_free(buf);
   }
-
-  size_t size = buf->size;
-  cudaFree(buf->data);
-  delete buf;
-  std::lock_guard lock(mutex_);
-  active_memory_ -= size;
 }
 
 size_t CudaAllocator::size(Buffer buffer) const {
@@ -98,6 +110,41 @@ size_t CudaAllocator::set_memory_limit(size_t limit) {
   return limit;
 }
 
+size_t CudaAllocator::get_cache_memory() const {
+  return buffer_cache_.cache_size();
+}
+
+size_t CudaAllocator::set_cache_limit(size_t limit) {
+  std::lock_guard lk(mutex_);
+  std::swap(limit, max_pool_size_);
+  return limit;
+}
+
+void CudaAllocator::clear_cache() {
+  std::lock_guard lk(mutex_);
+  buffer_cache_.clear();
+}
+
+void CudaAllocator::cuda_free(CudaBuffer* buf) {
+  // If cuda_free() is called from an unregistered thread, reschedule the call
+  // to the worker.
+  {
+    std::lock_guard lock(worker_mutex_);
+    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+      if (!worker_) {
+        worker_.reset(new Worker);
+      }
+      worker_->add_task([this, buf]() { this->cuda_free(buf); });
+      worker_->end_batch();
+      worker_->commit();
+      return;
+    }
+  }
+
+  cudaFree(buf->data);
+  delete buf;
+}
+
 CudaAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of CudaAllocator
   // will not be called on exit and buffers in the cache will be leaked. This
@@ -138,17 +185,19 @@ size_t set_memory_limit(size_t limit) {
 size_t get_memory_limit() {
   return cu::allocator().get_memory_limit();
 }
-
-// TODO: Implement buffer cache.
 size_t get_cache_memory() {
-  return 0;
+  return cu::allocator().get_cache_memory();
 }
-size_t set_cache_limit(size_t) {
-  return 0;
+size_t set_cache_limit(size_t limit) {
+  return cu::allocator().set_cache_limit(limit);
 }
+void clear_cache() {
+  cu::allocator().clear_cache();
+}
+
+// Not supported in CUDA.
 size_t set_wired_limit(size_t) {
   return 0;
 }
-void clear_cache() {}
 
 } // namespace mlx::core
diff --git a/mlx/backend/cuda/allocator.h b/mlx/backend/cuda/allocator.h
index 6c418ee7e..fe3755121 100644
--- a/mlx/backend/cuda/allocator.h
+++ b/mlx/backend/cuda/allocator.h
@@ -3,6 +3,7 @@
 #pragma once
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"
 
 #include <mutex>
 #include <set>
@@ -38,17 +39,24 @@ class CudaAllocator : public allocator::Allocator {
   void reset_peak_memory();
   size_t get_memory_limit();
   size_t set_memory_limit(size_t limit);
+  size_t get_cache_memory() const;
+  size_t set_cache_limit(size_t limit);
+  void clear_cache();
 
  private:
   CudaAllocator();
   friend CudaAllocator& allocator();
 
+  void cuda_free(CudaBuffer* buf);
+
   std::mutex worker_mutex_;
   std::unique_ptr<Worker> worker_;
   std::set<std::thread::id> allowed_threads_;
 
   std::mutex mutex_;
   size_t memory_limit_;
+  size_t max_pool_size_;
+  BufferCache<CudaBuffer> buffer_cache_;
   size_t active_memory_{0};
   size_t peak_memory_{0};
 };
diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index 5d8bd90d5..dd6189732 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -30,141 +30,18 @@ void* Buffer::raw_ptr() {
 
 namespace metal {
 
-namespace {
-
-BufferCache::BufferCache(ResidencySet& residency_set)
-    : head_(nullptr),
-      tail_(nullptr),
-      pool_size_(0),
-      residency_set_(residency_set) {}
-
-BufferCache::~BufferCache() {
-  auto pool = metal::new_scoped_memory_pool();
-  clear();
-}
-
-int BufferCache::clear() {
-  int n_release = 0;
-  for (auto& [size, holder] : buffer_pool_) {
-    if (holder->buf) {
-      if (!holder->buf->heap()) {
-        residency_set_.erase(holder->buf);
-      }
-      holder->buf->release();
-      n_release++;
-    }
-    delete holder;
-  }
-  buffer_pool_.clear();
-  pool_size_ = 0;
-  head_ = nullptr;
-  tail_ = nullptr;
-  return n_release;
-}
-
-MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  // Find the closest buffer in pool
-  MTL::Buffer* pbuf = nullptr;
-
-  auto it = buffer_pool_.lower_bound(size);
-
-  // Make sure we use most of the available memory
-  while (!pbuf && it != buffer_pool_.end() &&
-         it->first < std::min(2 * size, size + 2 * vm_page_size)) {
-    // Collect from the cache
-    pbuf = it->second->buf;
-
-    // Remove from cache
-    remove_from_list(it->second);
-    delete it->second;
-    it = buffer_pool_.erase(it);
-  }
-
-  if (pbuf) {
-    pool_size_ -= pbuf->length();
-  }
-
-  return pbuf;
-}
-
-void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  // Add to cache
-  if (buf) {
-    BufferHolder* bh = new BufferHolder(buf);
-    add_at_head(bh);
-    pool_size_ += buf->length();
-    buffer_pool_.insert({buf->length(), bh});
-  }
-}
-
-int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
-  if (min_bytes_to_free >= 0.9 * pool_size_) {
-    return clear();
-  } else {
-    int n_release = 0;
-    size_t total_bytes_freed = 0;
-
-    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
-      if (tail_->buf) {
-        total_bytes_freed += tail_->buf->length();
-        if (!tail_->buf->heap()) {
-          residency_set_.erase(tail_->buf);
-        }
-        tail_->buf->release();
-        tail_->buf = nullptr;
-        n_release++;
-      }
-      remove_from_list(tail_);
-    }
-    pool_size_ -= total_bytes_freed;
-    return n_release;
-  }
-}
-
-void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
-  if (!to_add)
-    return;
-
-  if (!head_) {
-    head_ = to_add;
-    tail_ = to_add;
-  } else {
-    head_->prev = to_add;
-    to_add->next = head_;
-    head_ = to_add;
-  }
-}
-
-void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
-  if (!to_remove) {
-    return;
-  }
-
-  // If in the middle
-  if (to_remove->prev && to_remove->next) {
-    to_remove->prev->next = to_remove->next;
-    to_remove->next->prev = to_remove->prev;
-  } else if (to_remove->prev && to_remove == tail_) { // If tail
-    tail_ = to_remove->prev;
-    tail_->next = nullptr;
-  } else if (to_remove == head_ && to_remove->next) { // If head
-    head_ = to_remove->next;
-    head_->prev = nullptr;
-  } else if (to_remove == head_ && to_remove == tail_) { // If only element
-    head_ = nullptr;
-    tail_ = nullptr;
-  }
-
-  to_remove->prev = nullptr;
-  to_remove->next = nullptr;
-}
-
-} // namespace
-
 MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
       residency_set_(device_),
-      buffer_cache_(residency_set_) {
+      buffer_cache_(
+          vm_page_size,
+          [](MTL::Buffer* buf) { return buf->length(); },
+          [this](MTL::Buffer* buf) {
+            if (!buf->heap()) {
+              residency_set_.erase(buf);
+            }
+            buf->release();
+          }) {
   auto pool = metal::new_scoped_memory_pool();
   auto memsize = std::get<size_t>(device_info().at("memory_size"));
   auto max_rec_size =
@@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
   if (heap_) {
     heap_->release();
   }
+  buffer_cache_.clear();
 }
 
 size_t MetalAllocator::set_cache_limit(size_t limit) {
diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h
index 227b09e91..691317916 100644
--- a/mlx/backend/metal/allocator.h
+++ b/mlx/backend/metal/allocator.h
@@ -7,6 +7,7 @@
 #include <vector>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/resident.h"
 
@@ -14,43 +15,6 @@ namespace mlx::core::metal {
 
 using allocator::Buffer;
 
-namespace {
-
-class BufferCache {
- public:
-  BufferCache(ResidencySet& residency_set);
-  ~BufferCache();
-
-  MTL::Buffer* reuse_from_cache(size_t size);
-  void recycle_to_cache(MTL::Buffer* buf);
-  int release_cached_buffers(size_t min_bytes_to_free);
-  size_t cache_size() {
-    return pool_size_;
-  }
-  int clear();
-
- private:
-  struct BufferHolder {
-   public:
-    BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
-
-    BufferHolder* prev;
-    BufferHolder* next;
-    MTL::Buffer* buf;
-  };
-
-  void add_at_head(BufferHolder* to_add);
-  void remove_from_list(BufferHolder* to_remove);
-
-  std::multimap<size_t, BufferHolder*> buffer_pool_;
-  BufferHolder* head_;
-  BufferHolder* tail_;
-  size_t pool_size_;
-  ResidencySet& residency_set_;
-};
-
-} // namespace
-
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
@@ -90,7 +54,7 @@
   friend MetalAllocator& allocator();
 
   // Caching allocator
-  BufferCache buffer_cache_;
+  BufferCache<MTL::Buffer> buffer_cache_;
 
   ResidencySet residency_set_;
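
Since the templated BufferCache<T> is now backend-agnostic, its reuse/recycle behavior can be exercised without any GPU. Below is a minimal sketch under stated assumptions: HostBuffer and kPageSize are hypothetical stand-ins, not part of this patch, and the two lambdas play the roles that the CUDA and Metal allocators fill with buf->size / cuda_free and buf->length() / release.

// Minimal sketch (illustrative only; requires the MLX source tree for the
// header path). HostBuffer and kPageSize are hypothetical stand-ins.
#include <cstddef>
#include <cstdlib>
#include <iostream>

#include "mlx/backend/common/buffer_cache.h"

struct HostBuffer {
  void* data;
  size_t size;
};

int main() {
  constexpr size_t kPageSize = 4096; // stand-in for getpagesize()/vm_page_size

  mlx::core::BufferCache<HostBuffer> cache(
      kPageSize,
      [](HostBuffer* buf) { return buf->size; }, // get_size
      [](HostBuffer* buf) {                      // free
        std::free(buf->data);
        delete buf;
      });

  // Allocate once, then recycle into the cache instead of freeing.
  auto* buf = new HostBuffer{std::malloc(1 << 20), 1 << 20};
  cache.recycle_to_cache(buf);

  // A much smaller request misses: a cached buffer is only reused when its
  // size is below min(2 * size, size + 2 * page_size).
  HostBuffer* miss = cache.reuse_from_cache(1 << 10); // nullptr
  // A same-sized request hits and removes the buffer from the pool.
  HostBuffer* hit = cache.reuse_from_cache(1 << 20); // returns buf
  std::cout << (miss == nullptr) << " " << (hit == buf) << "\n"; // prints: 1 1

  // Return it and flush; requesting >= 90% of the pool size clears everything.
  cache.recycle_to_cache(hit);
  cache.release_cached_buffers(cache.cache_size());
  return 0;
}

This also shows why both backends can share the cache: all backend-specific behavior (measuring a buffer, releasing it, updating the residency set on Metal) is confined to the two callbacks passed at construction.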