diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index 5069cbad7..a3e90df83 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -52,10 +52,14 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
   MTL::Buffer* pbuf = nullptr;
 
-  // Make sure we use most of the available memory
-  if (auto it = buffer_pool_.lower_bound(size); it != buffer_pool_.end() &&
-      it->first < std::min(2 * size, size + vm_page_size)) {
+  auto it = buffer_pool_.lower_bound(size);
+
+  // Make sure we use most of the available memory
+  while (!pbuf && it != buffer_pool_.end() &&
+         it->first < std::min(2 * size, size + 2 * vm_page_size)) {
     // Collect from the cache
     pbuf = it->second->buf;
+
     // Remove from cache
     remove_from_list(it->second);
     delete it->second;
+    it = buffer_pool_.erase(it);
@@ -81,6 +85,25 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
   }
 }
 
+void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
+  if (min_bytes_to_free >= 0.9 * pool_size_) {
+    clear();
+  } else {
+    std::lock_guard lk(cache_mutex_);
+    size_t total_bytes_freed = 0;
+
+    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
+      if (tail_->buf) {
+        total_bytes_freed += tail_->buf->length();
+        tail_->buf->release();
+        tail_->buf = nullptr;
+      }
+      remove_from_list(tail_);
+    }
+    pool_size_ -= total_bytes_freed;
+  }
+}
+
 void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
   if (!to_add)
     return;
@@ -96,8 +119,9 @@ void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
 }
 
 void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
-  if (!to_remove)
+  if (!to_remove) {
     return;
+  }
 
   // If in the middle
   if (to_remove->prev && to_remove->next) {
@@ -124,7 +148,8 @@ MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
       buffer_cache_(device_),
       peak_allocated_size_(0),
-      block_limit_(device_->recommendedMaxWorkingSetSize()) {}
+      block_limit_(device_->recommendedMaxWorkingSetSize()),
+      gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
 
 Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   // Align up memory
@@ -136,17 +161,19 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
 
   if (!buf) {
-    // If we have memory pressure, first check if we can reclaim some memory
-    // from the cache
-    if (device_->currentAllocatedSize() + size >= block_limit_) {
-      buffer_cache_.clear();
-    }
-
-    // If there is still too much memory pressure, fail (likely causes a wait).
+    // If there is too much memory pressure, fail (likely causes a wait).
     if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
       return Buffer{nullptr};
    }
 
+    // If we have a lot of memory pressure, check if we can reclaim some memory
+    // from the cache
+    if (device_->currentAllocatedSize() + size >= gc_limit_) {
+      size_t min_bytes_to_free =
+          size + device_->currentAllocatedSize() - gc_limit_;
+      buffer_cache_.release_cached_buffers(min_bytes_to_free);
+    }
+
     // Allocate new buffer if needed
     size_t res_opt = MTL::ResourceStorageModeShared;
     res_opt |= MTL::ResourceHazardTrackingModeTracked;
diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h
index bdb256095..45a58bc13 100644
--- a/mlx/backend/metal/allocator.h
+++ b/mlx/backend/metal/allocator.h
@@ -23,6 +23,7 @@ class BufferCache {
 
   MTL::Buffer* reuse_from_cache(size_t size);
   void recycle_to_cache(MTL::Buffer* buf);
+  void release_cached_buffers(size_t min_bytes_to_free);
 
  private:
   struct BufferHolder {
@@ -65,6 +66,7 @@ class MetalAllocator : public allocator::Allocator {
   // Allocation stats
   size_t peak_allocated_size_;
   size_t block_limit_;
+  size_t gc_limit_;
 };
 
 MetalAllocator& allocator();
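Note for reviewers: the reuse window in reuse_from_cache() changes in two ways. A cached buffer of length cand now satisfies a request of size bytes when cand < min(2 * size, size + 2 * vm_page_size), and the new while loop lets the search skip past holders whose buffer was already released by release_cached_buffers(). The small standalone check below illustrates the bound; the 16 KiB vm_page_size is an assumption for the example (typical of Apple Silicon), not a value taken from this patch.

// Illustrative only: evaluates the acceptance bound used by
// reuse_from_cache() above. Compile with: c++ -std=c++17 window.cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  const size_t vm_page_size = 16384;  // assumed page size for the example
  auto reusable = [&](size_t size, size_t cand) {
    return cand < std::min(2 * size, size + 2 * vm_page_size);
  };
  // Small request: the 2 * size term dominates, capping waste at ~2x.
  std::cout << reusable(4096, 8191) << "\n";  // 1 (strictly under 2 * 4096)
  std::cout << reusable(4096, 8192) << "\n";  // 0 (the bound is exclusive)
  // Large request: the page term dominates, capping waste at two VM pages.
  std::cout << reusable(1 << 20, (1 << 20) + 32767) << "\n";  // 1
  std::cout << reusable(1 << 20, (1 << 20) + 32768) << "\n";  // 0
  return 0;
}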
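Note for reviewers: release_cached_buffers() walks the LRU tail of the cache and frees buffers until at least min_bytes_to_free bytes are released, falling back to clear() when 90% or more of the pool would have to go anyway. The sketch below is a minimal model of that policy, assuming a toy std::list<size_t> of buffer lengths in place of MTL::Buffer and the BufferHolder list; ToyCache and its members are hypothetical names for illustration, not MLX API.

#include <cstddef>
#include <iostream>
#include <list>

struct ToyCache {
  std::list<size_t> pool;  // front = most recently used, back = LRU tail
  size_t pool_size = 0;    // running total of cached bytes

  void release(size_t min_bytes_to_free) {
    // Mirrors the fast path above: if nearly the whole pool must go,
    // drop everything at once instead of walking the list.
    if (min_bytes_to_free >= 0.9 * pool_size) {
      pool.clear();
      pool_size = 0;
      return;
    }
    // Mirrors the while (tail_ && ...) loop: evict from the LRU end
    // until enough bytes have been freed.
    size_t freed = 0;
    while (!pool.empty() && freed < min_bytes_to_free) {
      freed += pool.back();
      pool.pop_back();
    }
    pool_size -= freed;
  }
};

int main() {
  ToyCache cache;
  for (size_t s : {4096, 8192, 16384}) {  // 16384 ends up most recently used
    cache.pool.push_front(s);
    cache.pool_size += s;
  }
  cache.release(10000);  // evicts 4096 and 8192, the two oldest entries
  std::cout << cache.pool_size << "\n";  // prints 16384
  return 0;
}

Combined with the malloc() change, the net effect is that crossing gc_limit_ (0.95 * recommendedMaxWorkingSetSize()) frees only the overshoot, size + currentAllocatedSize() - gc_limit_, instead of dropping the entire cache as the old block_limit_ path did.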