From ad5036072c6c481cac5086f7695b1d528a50fe78 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 25 Dec 2023 12:50:38 -0800 Subject: [PATCH] try no cache --- mlx/backend/metal/allocator.cpp | 155 +++++--------------------------- mlx/backend/metal/allocator.h | 46 +--------- 2 files changed, 21 insertions(+), 180 deletions(-) diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 0ac9655a1..dab2c1ae0 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -1,5 +1,6 @@ // Copyright © 2023 Apple Inc. +#include #include "mlx/backend/metal/allocator.h" #include "mlx/backend/metal/metal.h" @@ -23,155 +24,40 @@ void* Buffer::raw_ptr() { namespace metal { -namespace { - -BufferCache::BufferCache(MTL::Device* device) - : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {} - -BufferCache::~BufferCache() { - clear(); -} - -void BufferCache::clear() { - std::lock_guard lk(cache_mutex_); - for (auto& [size, holder] : buffer_pool_) { - if (holder->buf) - holder->buf->release(); - delete holder; - } - buffer_pool_.clear(); - pool_size_ = 0; - head_ = nullptr; - tail_ = nullptr; -} - -MTL::Buffer* BufferCache::reuse_from_cache(size_t size) { - std::lock_guard lk(cache_mutex_); - - // Find the closest buffer in pool - MTL::Buffer* pbuf = nullptr; - - // Make sure we use > 50% of the available memory - if (auto it = buffer_pool_.lower_bound(size); - it != buffer_pool_.end() && it->first < 2 * size) { - // Collect from the cache - pbuf = it->second->buf; - // Remove from cache - remove_from_list(it->second); - delete it->second; - it = buffer_pool_.erase(it); - } - - if (pbuf) { - pool_size_ -= pbuf->length(); - } - - return pbuf; -} - -void BufferCache::recycle_to_cache(MTL::Buffer* buf) { - std::lock_guard lk(cache_mutex_); - - // Add to cache - if (buf) { - BufferHolder* bh = new BufferHolder(buf); - add_at_head(bh); - pool_size_ += buf->length(); - buffer_pool_.insert({buf->length(), bh}); - } -} - -void BufferCache::release_cached_buffers(size_t min_bytes_to_free) { - if (min_bytes_to_free >= 0.9 * pool_size_) { - clear(); - } else { - std::lock_guard lk(cache_mutex_); - size_t total_bytes_freed = 0; - while (tail_ && (total_bytes_freed < min_bytes_to_free)) { - if (tail_->buf) { - total_bytes_freed += tail_->buf->length(); - tail_->buf->release(); - tail_->buf = nullptr; - } - remove_from_list(tail_); - } - pool_size_ -= total_bytes_freed; - } -} - -void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) { - if (!to_add) - return; - - if (!head_) { - head_ = to_add; - tail_ = to_add; - } else { - head_->prev = to_add; - to_add->next = head_; - head_ = to_add; - } -} - -void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { - if (!to_remove) - return; - - // If in the middle - if (to_remove->prev && to_remove->next) { - to_remove->prev->next = to_remove->next; - to_remove->next->prev = to_remove->prev; - } else if (to_remove->prev && to_remove == tail_) { // If tail - tail_ = to_remove->prev; - tail_->next = nullptr; - } else if (to_remove == head_ && to_remove->next) { // If head - head_ = to_remove->next; - head_->prev = nullptr; - } else if (to_remove == head_ && to_remove == tail_) { // If only element - head_ = nullptr; - tail_ = nullptr; - } - - to_remove->prev = nullptr; - to_remove->next = nullptr; -} - -} // namespace - MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), - buffer_cache_(device_), peak_allocated_size_(0), - block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {} + block_limit_(device_->recommendedMaxWorkingSetSize()) {} -Buffer MetalAllocator::malloc(size_t size) { +Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { // Align up memory - if (size > vm_page_size) { - size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size); - } + ///if (size > vm_page_size) { + // size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size); + //} - MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size); +// MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size); // Prepare to allocate new memory as needed - if (!buf) { - // First check if the cache is big but nothing fits, garbage collect - // if so - // TODO maybe block limit and gc limit should be different - if (buffer_cache_.size() >= block_limit_) { - buffer_cache_.release_cached_buffers( - std::max(buffer_cache_.size() - block_limit_, size)); - } +// if (!buf) { + // If we have memory pressure, first check if we can reclaim some memory + // from the cache +// if (auto new_size = device_->currentAllocatedSize() + size; new_size >= block_limit_) { +// buffer_cache_.clear(); +// buffer_cache_.release_cached_buffers( +// std::max(new_size - block_limit_, size)); +// } // If there is still too much memory pressure, fail (likely causes a wait). - if (device_->currentAllocatedSize() >= block_limit_) { + // size + allocated (to avoid going over the limit) + if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) { return Buffer{nullptr}; } +// } // Allocate new buffer if needed size_t res_opt = MTL::ResourceStorageModeShared; res_opt |= MTL::ResourceHazardTrackingModeTracked; - buf = device_->newBuffer(size, res_opt); - } + auto buf = device_->newBuffer(size, res_opt); peak_allocated_size_ = std::max(peak_allocated_size_, device_->currentAllocatedSize()); @@ -180,8 +66,7 @@ Buffer MetalAllocator::malloc(size_t size) { } void MetalAllocator::free(Buffer buffer) { - auto buf = static_cast(buffer.ptr()); - buffer_cache_.recycle_to_cache(buf); + static_cast(buffer.ptr())->release(); } MetalAllocator& allocator() { diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index ff5d902c6..a3a790b41 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -13,51 +13,10 @@ namespace mlx::core::metal { using allocator::Buffer; -namespace { - -class BufferCache { - public: - BufferCache(MTL::Device* device); - ~BufferCache(); - void clear(); - - MTL::Buffer* reuse_from_cache(size_t size); - void recycle_to_cache(MTL::Buffer* buf); - void release_cached_buffers(size_t min_bytes_to_free); - - // Returnt he size in bytes of cached memory - size_t size() { - return pool_size_; - } - - private: - struct BufferHolder { - public: - BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {} - - BufferHolder* prev; - BufferHolder* next; - MTL::Buffer* buf; - }; - - void add_at_head(BufferHolder* to_add); - void remove_from_list(BufferHolder* to_remove); - - MTL::Device* device_; - std::mutex cache_mutex_; - - std::multimap buffer_pool_; - BufferHolder* head_; - BufferHolder* tail_; - size_t pool_size_; -}; - -} // namespace - class MetalAllocator : public allocator::Allocator { /** Allocator for Metal GPUs. */ public: - virtual Buffer malloc(size_t size) override; + virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual void free(Buffer buffer) override; private: @@ -65,9 +24,6 @@ class MetalAllocator : public allocator::Allocator { MetalAllocator(); friend MetalAllocator& allocator(); - // Caching allocator - BufferCache buffer_cache_; - // Allocation stats size_t peak_allocated_size_; size_t block_limit_;