diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index 5d95cae4a..51cc3a109 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -52,8 +52,8 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
   MTL::Buffer* pbuf = nullptr;
   // Make sure we use > 50% of the available memory
-  if (auto it = buffer_pool_.lower_bound(size);
-      it != buffer_pool_.end() && it->first < 2 * size) {
+  if (auto it = buffer_pool_.lower_bound(size); it != buffer_pool_.end() &&
+      it->first < std::min(2 * size, size + vm_page_size)) {
     // Collect from the cache
     pbuf = it->second->buf;

     // Remove from cache
@@ -150,17 +150,18 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
     size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
   }

+  // Try the cache
   MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);

-  // Prepare to allocate new memory as needed
   if (!buf) {
-    // If there is still too much memory pressure, fail (likely causes a wait).
-    if (auto new_size = device_->currentAllocatedSize() + size;
-        new_size >= block_limit_) {
+    // If we have memory pressure, first check if we can reclaim some memory
+    // from the cache
+    if (device_->currentAllocatedSize() + size >= block_limit_) {
       buffer_cache_.clear();
     }
-    if (device_->currentAllocatedSize() >= block_limit_) {
+    // If there is still too much memory pressure, fail (likely causes a wait).
+    if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
       return Buffer{nullptr};
     }
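
For context, here is a small standalone sketch of what the tightened reuse bound does. This is not MLX code: the `std::multimap` pool, the `reuse_from_cache` free function, and the hard-coded 16 KiB `vm_page_size` are all assumptions made so the example compiles on its own. The point it illustrates is that for small requests the old 2x rule still governs, while for large requests the acceptable slack is now capped at one VM page, so a 1 MiB cached buffer no longer satisfies a 600 KiB request.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <map>
#include <optional>

// Assumption for illustration: 16 KiB VM pages (typical on Apple Silicon).
constexpr size_t vm_page_size = 16384;

// Toy stand-in for MLX's buffer_pool_: capacity -> buffer id.
std::multimap<size_t, size_t> pool;

// Mirrors the diff's reuse rule: take the smallest cached buffer that fits,
// but only if its capacity is below min(2 * size, size + vm_page_size).
std::optional<size_t> reuse_from_cache(size_t size) {
  auto it = pool.lower_bound(size);
  if (it != pool.end() &&
      it->first < std::min(2 * size, size + vm_page_size)) {
    size_t id = it->second;
    pool.erase(it);  // remove from cache, as in the original
    return id;
  }
  return std::nullopt;
}

int main() {
  pool.insert({1 << 20, 1});  // one cached 1 MiB buffer

  // 600 KiB request: the old `< 2 * size` rule would accept the 1 MiB
  // buffer (wasting ~70%), but 1 MiB >= 600 KiB + one page, so the new
  // bound rejects it and forces a fresh, right-sized allocation.
  std::printf("600 KiB reuse: %s\n",
              reuse_from_cache(600 * 1024) ? "hit" : "miss");

  // An exact 1 MiB request still reuses the cached buffer.
  std::printf("1 MiB reuse:   %s\n",
              reuse_from_cache(1 << 20) ? "hit" : "miss");
}
```

The second hunk follows the same spirit: instead of failing as soon as total allocation would cross `block_limit_`, `malloc` first drops the cache and rechecks, and the `allow_swap` escape hatch lets a caller proceed past the limit rather than receive a null `Buffer`.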