From 82e4fbc1fd7ec7cd746e922f52cbcf7ed4e63475 Mon Sep 17 00:00:00 2001
From: Cheng
Date: Sat, 24 May 2025 10:45:49 +0900
Subject: [PATCH] Move BufferCache out of allocator

---
 mlx/backend/common/buffer_cache.h | 200 ++++++++++++++++++++++++++++++
 mlx/backend/metal/allocator.cpp   | 142 ++-------------------
 mlx/backend/metal/allocator.h     |  40 +-----
 3 files changed, 212 insertions(+), 170 deletions(-)
 create mode 100644 mlx/backend/common/buffer_cache.h

diff --git a/mlx/backend/common/buffer_cache.h b/mlx/backend/common/buffer_cache.h
new file mode 100644
index 000000000..12ecd80ce
--- /dev/null
+++ b/mlx/backend/common/buffer_cache.h
@@ -0,0 +1,200 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <algorithm>
+#include <cstddef>
+#include <functional>
+#include <map>
+#include <utility>
+
+namespace mlx::core {
+
+// A cache of raw buffers, indexed by size and evicted oldest-first.
+//
+// Each cached buffer is wrapped in a BufferHolder that is referenced from two
+// places: a multimap keyed by buffer size (used to find a fit on reuse) and a
+// doubly linked list ordered by recycle time (head is the most recently
+// recycled holder, tail the oldest). The two structures must always describe
+// the same set of holders.
+//
+// The cache does not know how to measure or release a T; both operations are
+// supplied by the owner as callbacks. The class does no locking of its own.
+template <typename T>
+class BufferCache {
+ public:
+  // page_size: slack unit used by reuse_from_cache() when matching sizes.
+  // get_size: returns the size of a buffer in bytes. It must keep returning
+  //     the same value for a buffer for as long as that buffer is cached.
+  // free: releases a buffer that is being dropped from the cache.
+  BufferCache(
+      size_t page_size,
+      std::function<size_t(T*)> get_size,
+      std::function<void(T*)> free)
+      : page_size_(page_size),
+        get_size_(std::move(get_size)),
+        free_(std::move(free)) {}
+
+  // Frees every buffer that is still cached.
+  ~BufferCache() {
+    clear();
+  }
+
+  BufferCache(const BufferCache&) = delete;
+  BufferCache& operator=(const BufferCache&) = delete;
+
+  // Removes and returns the smallest cached buffer whose size is at least
+  // `size` and below min(2 * size, size + 2 * page_size), or nullptr if no
+  // cached buffer qualifies. The caller takes ownership of the result.
+  T* reuse_from_cache(size_t size) {
+    // Find the closest buffer in pool.
+    T* pbuf = nullptr;
+
+    auto it = buffer_pool_.lower_bound(size);
+
+    // Make sure we use most of the available memory.
+    while (!pbuf && it != buffer_pool_.end() &&
+           it->first < std::min(2 * size, size + 2 * page_size_)) {
+      // Collect from the cache.
+      pbuf = it->second->buf;
+
+      // Remove from cache. remove_from_list() also deletes the holder.
+      remove_from_list(it->second);
+      it = buffer_pool_.erase(it);
+    }
+
+    if (pbuf) {
+      pool_size_ -= get_size_(pbuf);
+    }
+
+    return pbuf;
+  }
+
+  // Takes ownership of `buf` and caches it as the most recently recycled
+  // buffer. A null `buf` is ignored.
+  void recycle_to_cache(T* buf) {
+    if (buf) {
+      BufferHolder* bh = new BufferHolder(buf);
+      add_at_head(bh);
+      size_t size = get_size_(buf);
+      pool_size_ += size;
+      buffer_pool_.insert({size, bh});
+    }
+  }
+
+  // Frees cached buffers, oldest first, until at least `min_bytes_to_free`
+  // bytes have been freed or the cache is empty. If the request covers 90% or
+  // more of the cached bytes the whole cache is cleared instead. Returns the
+  // number of buffers freed.
+  int release_cached_buffers(size_t min_bytes_to_free) {
+    if (min_bytes_to_free >= 0.9 * pool_size_) {
+      return clear();
+    }
+
+    int n_release = 0;
+    size_t total_bytes_freed = 0;
+
+    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
+      // Holders are only created for non-null buffers.
+      const size_t size = get_size_(tail_->buf);
+      total_bytes_freed += size;
+      free_(tail_->buf);
+      n_release++;
+
+      // Erase the tail's map entry before remove_from_list() deletes the
+      // holder and moves tail_; doing it afterwards would match the wrong
+      // holder and leave a dangling pointer in buffer_pool_.
+      auto [first, last] = buffer_pool_.equal_range(size);
+      auto it = std::find_if(first, last, [this](const auto& entry) {
+        return entry.second == tail_;
+      });
+      if (it != last) {
+        buffer_pool_.erase(it);
+      }
+      remove_from_list(tail_);
+    }
+    pool_size_ -= total_bytes_freed;
+    return n_release;
+  }
+
+  // Frees every cached buffer and returns how many were freed.
+  int clear() {
+    int n_release = 0;
+    for (auto& [size, holder] : buffer_pool_) {
+      if (holder->buf) {
+        free_(holder->buf);
+        n_release++;
+      }
+      delete holder;
+    }
+    buffer_pool_.clear();
+    pool_size_ = 0;
+    head_ = nullptr;
+    tail_ = nullptr;
+    return n_release;
+  }
+
+  // Total size in bytes of the buffers currently cached.
+  size_t cache_size() const {
+    return pool_size_;
+  }
+
+  // The page size passed to the constructor.
+  size_t page_size() const {
+    return page_size_;
+  }
+
+ private:
+  // Node of the recycle-order list wrapping one cached buffer.
+  struct BufferHolder {
+   public:
+    explicit BufferHolder(T* buf_) : buf(buf_) {}
+
+    BufferHolder* prev{nullptr};
+    BufferHolder* next{nullptr};
+    T* buf;
+  };
+
+  // Links `to_add` in as the new head (most recently recycled) of the list.
+  void add_at_head(BufferHolder* to_add) {
+    if (!head_) {
+      head_ = to_add;
+      tail_ = to_add;
+    } else {
+      head_->prev = to_add;
+      to_add->next = head_;
+      head_ = to_add;
+    }
+  }
+
+  // Unlinks `to_remove` from the list and deletes it. The caller is
+  // responsible for erasing the matching entry from buffer_pool_.
+  void remove_from_list(BufferHolder* to_remove) {
+    if (to_remove->prev && to_remove->next) { // if middle
+      to_remove->prev->next = to_remove->next;
+      to_remove->next->prev = to_remove->prev;
+    } else if (to_remove->prev && to_remove == tail_) { // if tail
+      tail_ = to_remove->prev;
+      tail_->next = nullptr;
+    } else if (to_remove == head_ && to_remove->next) { // if head
+      head_ = to_remove->next;
+      head_->prev = nullptr;
+    } else if (to_remove == head_ && to_remove == tail_) { // if only element
+      head_ = nullptr;
+      tail_ = nullptr;
+    }
+
+    delete to_remove;
+  }
+
+  std::multimap<size_t, BufferHolder*> buffer_pool_;
+  BufferHolder* head_{nullptr};
+  BufferHolder* tail_{nullptr};
+  size_t pool_size_{0};
+
+  const size_t page_size_;
+  std::function<size_t(T*)> get_size_;
+  std::function<void(T*)> free_;
+};
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index 5d8bd90d5..dd6189732 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -30,141 +30,18 @@ void* Buffer::raw_ptr() {
 
 namespace metal {
 
-namespace {
-
-BufferCache::BufferCache(ResidencySet& residency_set)
-    : head_(nullptr),
-      tail_(nullptr),
-      pool_size_(0),
-      residency_set_(residency_set) {}
-
-BufferCache::~BufferCache() {
-  auto pool = metal::new_scoped_memory_pool();
-  clear();
-}
-
-int BufferCache::clear() {
-  int n_release = 0;
-  for (auto& [size, holder] : buffer_pool_) {
-    if (holder->buf) {
-      if (!holder->buf->heap()) {
-        residency_set_.erase(holder->buf);
-      }
-      holder->buf->release();
-      n_release++;
-    }
-    delete holder;
-  }
-  buffer_pool_.clear();
-  pool_size_ = 0;
-  head_ = nullptr;
-  tail_ = nullptr;
-  return n_release;
-}
-
-MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
-  // Find the closest buffer in pool
-  MTL::Buffer* pbuf = nullptr;
-
-  auto it = buffer_pool_.lower_bound(size);
-
-  // Make sure we use most of the available memory
-  while (!pbuf && it != buffer_pool_.end() &&
-         it->first < std::min(2 * size, size + 2 * vm_page_size)) {
-    // Collect from the cache
-    pbuf = it->second->buf;
-
-    // Remove from cache
-    remove_from_list(it->second);
-    delete it->second;
-    it = buffer_pool_.erase(it);
-  }
-
-  if (pbuf) {
-    pool_size_ -= pbuf->length();
-  }
-
-  return pbuf;
-}
-
-void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
-  // Add to cache
-  if (buf) {
-    BufferHolder* bh = new BufferHolder(buf);
-    add_at_head(bh);
-    pool_size_ += buf->length();
-    buffer_pool_.insert({buf->length(), bh});
-  }
-}
-
-int BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
-  if (min_bytes_to_free >= 0.9 * pool_size_) {
-    return clear();
-  } else {
-    int n_release = 0;
-    size_t total_bytes_freed = 0;
-
-    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
-      if (tail_->buf) {
-        total_bytes_freed += tail_->buf->length();
-        if (!tail_->buf->heap()) {
-          residency_set_.erase(tail_->buf);
-        }
-        tail_->buf->release();
-        tail_->buf = nullptr;
-        n_release++;
-      }
-      remove_from_list(tail_);
-    }
-    pool_size_ -= total_bytes_freed;
-    return n_release;
-  }
-}
-
-void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
-  if (!to_add)
-    return;
-
-  if (!head_) {
-    head_ = to_add;
-    tail_ = to_add;
-  } else {
-    head_->prev = to_add;
-    to_add->next = head_;
-    head_ = to_add;
-  }
-}
-
-void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
-  if (!to_remove) {
-    return;
-  }
-
-  // If in the middle
-  if (to_remove->prev && to_remove->next) {
-    to_remove->prev->next = to_remove->next;
-    to_remove->next->prev = to_remove->prev;
-  } else if (to_remove->prev && to_remove == tail_) { // If tail
-    tail_ = to_remove->prev;
-    tail_->next = nullptr;
-  } else if (to_remove == head_ && to_remove->next) { // If head
-    head_ = to_remove->next;
-    head_->prev = nullptr;
-  } else if (to_remove == head_ && to_remove == tail_) { // If only element
-    head_ = nullptr;
-    tail_ = nullptr;
-  }
-
-  to_remove->prev = nullptr;
-  to_remove->next = nullptr;
-}
-
-} // namespace
-
 MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
       residency_set_(device_),
-      buffer_cache_(residency_set_) {
+      buffer_cache_(
+          vm_page_size,
+          [](MTL::Buffer* buf) { return buf->length(); },
+          [this](MTL::Buffer* buf) {
+            if (!buf->heap()) {
+              residency_set_.erase(buf);
+            }
+            buf->release();
+          }) {
   auto pool = metal::new_scoped_memory_pool();
   auto memsize = std::get<size_t>(device_info().at("memory_size"));
   auto max_rec_size =
@@ -193,6 +70,7 @@ MetalAllocator::~MetalAllocator() {
   if (heap_) {
     heap_->release();
   }
+  buffer_cache_.clear();
 }
 
 size_t MetalAllocator::set_cache_limit(size_t limit) {
diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h
index 227b09e91..691317916 100644
--- a/mlx/backend/metal/allocator.h
+++ b/mlx/backend/metal/allocator.h
@@ -7,6 +7,7 @@
 #include <vector>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/buffer_cache.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/resident.h"
 
@@ -14,43 +15,6 @@ namespace mlx::core::metal {
 
 using allocator::Buffer;
 
-namespace {
-
-class BufferCache {
- public:
-  BufferCache(ResidencySet& residency_set);
-  ~BufferCache();
-
-  MTL::Buffer* reuse_from_cache(size_t size);
-  void recycle_to_cache(MTL::Buffer* buf);
-  int release_cached_buffers(size_t min_bytes_to_free);
-  size_t cache_size() {
-    return pool_size_;
-  }
-  int clear();
-
- private:
-  struct BufferHolder {
-   public:
-    BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
-
-    BufferHolder* prev;
-    BufferHolder* next;
-    MTL::Buffer* buf;
-  };
-
-  void add_at_head(BufferHolder* to_add);
-  void remove_from_list(BufferHolder* to_remove);
-
-  std::multimap<size_t, BufferHolder*> buffer_pool_;
-  BufferHolder* head_;
-  BufferHolder* tail_;
-  size_t pool_size_;
-  ResidencySet& residency_set_;
-};
-
-} // namespace
-
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
@@ -90,7 +54,7 @@ class MetalAllocator : public allocator::Allocator {
   friend MetalAllocator& allocator();
 
   // Caching allocator
-  BufferCache buffer_cache_;
+  BufferCache<MTL::Buffer> buffer_cache_;
 
   ResidencySet residency_set_;
 