diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index af4dd2e36..0ac9655a1 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -26,11 +26,7 @@ namespace metal {
 namespace {
 
 BufferCache::BufferCache(MTL::Device* device)
-    : device_(device),
-      head_(nullptr),
-      tail_(nullptr),
-      pool_size_(0),
-      gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {}
+    : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
 
 BufferCache::~BufferCache() {
   clear();
@@ -54,10 +50,10 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
   // Find the closest buffer in pool
   MTL::Buffer* pbuf = nullptr;
-  auto it = buffer_pool_.lower_bound(size);
 
   // Make sure we use > 50% of the available memory
-  while (!pbuf && it != buffer_pool_.end() && it->first < 2 * size) {
+  if (auto it = buffer_pool_.lower_bound(size);
+      it != buffer_pool_.end() && it->first < 2 * size) {
     // Collect from the cache
     pbuf = it->second->buf;
 
     // Remove from cache
@@ -85,17 +81,12 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
   }
 }
 
-size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
-  min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_;
-
+void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
   if (min_bytes_to_free >= 0.9 * pool_size_) {
-    size_t old_pool_size = pool_size_;
     clear();
-    return old_pool_size;
   } else {
     std::lock_guard lk(cache_mutex_);
     size_t total_bytes_freed = 0;
-
     while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
       if (tail_->buf) {
         total_bytes_freed += tail_->buf->length();
@@ -104,9 +95,7 @@ size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
       }
       remove_from_list(tail_);
     }
-
     pool_size_ -= total_bytes_freed;
-    return total_bytes_freed;
   }
 }
 
@@ -165,14 +154,17 @@ Buffer MetalAllocator::malloc(size_t size) {
 
   // Prepare to allocate new memory as needed
   if (!buf) {
-    // If we are under very high memory pressure, we don't allocate further
-    if (device_->currentAllocatedSize() >= block_limit_) {
-      return Buffer{nullptr};
+    // First check whether the cache is large but nothing in it fits;
+    // garbage collect if so.
+    // TODO: maybe the block limit and gc limit should be different
+    if (buffer_cache_.size() >= block_limit_) {
+      buffer_cache_.release_cached_buffers(
+          std::max(buffer_cache_.size() - block_limit_, size));
     }
 
-    // If we are still under memory pressure, try cleaning cache
-    if (buffer_cache_.can_garbage_collect()) {
-      buffer_cache_.release_cached_buffers(size);
+    // If there is still too much memory pressure, fail (likely causes a wait).
+    if (device_->currentAllocatedSize() >= block_limit_) {
+      return Buffer{nullptr};
     }
 
     // Allocate new buffer if needed
diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h
index 621b78a92..ff5d902c6 100644
--- a/mlx/backend/metal/allocator.h
+++ b/mlx/backend/metal/allocator.h
@@ -23,10 +23,11 @@ class BufferCache {
 
   MTL::Buffer* reuse_from_cache(size_t size);
   void recycle_to_cache(MTL::Buffer* buf);
-  size_t release_cached_buffers(size_t min_bytes_to_free);
+  void release_cached_buffers(size_t min_bytes_to_free);
 
-  bool can_garbage_collect() {
-    return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_;
+  // Return the size in bytes of cached memory
+  size_t size() {
+    return pool_size_;
   }
 
  private:
@@ -49,7 +50,6 @@ class BufferCache {
   BufferHolder* head_;
   BufferHolder* tail_;
   size_t pool_size_;
-  size_t gc_limit_;
 };
 
 } // namespace
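
Reviewer note (not part of the patch): below is a minimal standalone C++ sketch of the allocation policy the patch moves to, namely reuse from the cache, then shrink the cache when it exceeds the block limit but nothing in it fits, and only then fail on hard memory pressure. FakeBuffer, FakeDevice, Cache, and my_malloc are stand-ins invented for this note, not MLX or Metal APIs; the real code operates on MTL::Buffer through the Metal device.

// Illustrative sketch only -- FakeBuffer, FakeDevice, Cache, and my_malloc
// are invented stand-ins, not MLX or Metal APIs. Requires C++17.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>

struct FakeBuffer {
  size_t length;
};

struct FakeDevice {
  size_t allocated = 0; // bytes currently allocated on the "device"
};

class Cache {
 public:
  // Reuse a cached buffer whose size is in [size, 2 * size), mirroring the
  // "> 50% of the available memory" rule in reuse_from_cache.
  FakeBuffer* reuse(size_t size) {
    if (auto it = pool_.lower_bound(size);
        it != pool_.end() && it->first < 2 * size) {
      FakeBuffer* buf = it->second;
      pool_size_ -= buf->length;
      pool_.erase(it);
      return buf;
    }
    return nullptr;
  }

  // Free cached buffers until at least min_bytes_to_free bytes are released
  // (smallest-first here for brevity; the patch evicts from the LRU tail).
  void release_cached_buffers(size_t min_bytes_to_free, FakeDevice& dev) {
    size_t freed = 0;
    while (!pool_.empty() && freed < min_bytes_to_free) {
      auto it = pool_.begin();
      freed += it->first;
      dev.allocated -= it->first;
      delete it->second;
      pool_.erase(it);
    }
    pool_size_ -= freed;
  }

  // Bytes held in the cache, like the new BufferCache::size()
  size_t size() const {
    return pool_size_;
  }

 private:
  std::multimap<size_t, FakeBuffer*> pool_;
  size_t pool_size_ = 0;
};

// The malloc-time policy from the patch: reuse, then garbage collect an
// oversized cache, then enforce the hard block limit.
FakeBuffer* my_malloc(
    size_t size, Cache& cache, FakeDevice& dev, size_t block_limit) {
  if (FakeBuffer* buf = cache.reuse(size)) {
    return buf;
  }
  // The cache holds a lot of memory but nothing in it fits: shrink it first.
  if (cache.size() >= block_limit) {
    cache.release_cached_buffers(
        std::max(cache.size() - block_limit, size), dev);
  }
  // Still too much memory pressure: fail instead of over-allocating.
  if (dev.allocated >= block_limit) {
    return nullptr;
  }
  dev.allocated += size;
  return new FakeBuffer{size};
}

int main() {
  FakeDevice dev;
  Cache cache;
  const size_t block_limit = 1 << 20; // 1 MiB hard limit for the example

  FakeBuffer* a = my_malloc(512 * 1024, cache, dev, block_limit);
  FakeBuffer* b = my_malloc(512 * 1024, cache, dev, block_limit);
  FakeBuffer* c = my_malloc(1024, cache, dev, block_limit); // over the limit
  std::cout << std::boolalpha << (a != nullptr) << ' ' << (b != nullptr)
            << ' ' << (c == nullptr) << '\n'; // prints: true true true
  delete a;
  delete b;
}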