diff --git a/mlx/allocator.cpp b/mlx/allocator.cpp index b591caeeb..86fa5974f 100644 --- a/mlx/allocator.cpp +++ b/mlx/allocator.cpp @@ -9,7 +9,7 @@ namespace mlx::core::allocator { Buffer malloc(size_t size) { - auto buffer = allocator().malloc(size); + auto buffer = allocator().malloc(size, /* allow_swap */ true); if (size && !buffer.ptr()) { std::ostringstream msg; msg << "[malloc] Unable to allocate " << size << " bytes."; @@ -22,7 +22,7 @@ void free(Buffer buffer) { return allocator().free(buffer); } -Buffer CommonAllocator::malloc(size_t size) { +Buffer CommonAllocator::malloc(size_t size, bool) { return Buffer{std::malloc(size)}; } @@ -38,6 +38,11 @@ Buffer malloc_or_wait(size_t size) { buffer = allocator().malloc(size); } + // Try swapping if needed + if (size && !buffer.ptr()) { + buffer = allocator().malloc(size, /* allow_swap = */ true); + } + if (size && !buffer.ptr()) { std::ostringstream msg; msg << "[malloc_or_wait] Unable to allocate " << size << " bytes."; diff --git a/mlx/allocator.h b/mlx/allocator.h index ce0c1cd00..1061d6cce 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -39,7 +39,7 @@ Buffer malloc_or_wait(size_t size); class Allocator { /** Abstract base class for a memory allocator. */ public: - virtual Buffer malloc(size_t size) = 0; + virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; virtual void free(Buffer buffer) = 0; Allocator() = default; @@ -55,7 +55,7 @@ Allocator& allocator(); class CommonAllocator : public Allocator { /** A general CPU allocator. */ public: - virtual Buffer malloc(size_t size) override; + virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual void free(Buffer buffer) override; private: diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index af4dd2e36..07f502998 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -26,11 +26,7 @@ namespace metal { namespace { BufferCache::BufferCache(MTL::Device* device) - : device_(device), - head_(nullptr), - tail_(nullptr), - pool_size_(0), - gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {} + : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {} BufferCache::~BufferCache() { clear(); @@ -54,12 +50,16 @@ MTL::Buffer* BufferCache::reuse_from_cache(size_t size) { // Find the closest buffer in pool MTL::Buffer* pbuf = nullptr; + + // Make sure we use most of the available memory auto it = buffer_pool_.lower_bound(size); - // Make sure we use > 50% of the available memory - while (!pbuf && it != buffer_pool_.end() && it->first < 2 * size) { + // Make sure we use most of the available memory + while (!pbuf && it != buffer_pool_.end() && + it->first < std::min(2 * size, size + 2 * vm_page_size)) { // Collect from the cache pbuf = it->second->buf; + // Remove from cache remove_from_list(it->second); delete it->second; @@ -85,13 +85,9 @@ void BufferCache::recycle_to_cache(MTL::Buffer* buf) { } } -size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) { - min_bytes_to_free += device_->currentAllocatedSize() - gc_limit_; - +void BufferCache::release_cached_buffers(size_t min_bytes_to_free) { if (min_bytes_to_free >= 0.9 * pool_size_) { - size_t old_pool_size = pool_size_; clear(); - return old_pool_size; } else { std::lock_guard lk(cache_mutex_); size_t total_bytes_freed = 0; @@ -104,9 +100,7 @@ size_t BufferCache::release_cached_buffers(size_t min_bytes_to_free) { } remove_from_list(tail_); } - pool_size_ -= total_bytes_freed; - return total_bytes_freed; } } @@ -125,8 +119,9 @@ void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) { } void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { - if (!to_remove) + if (!to_remove) { return; + } // If in the middle if (to_remove->prev && to_remove->next) { @@ -153,26 +148,30 @@ MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), buffer_cache_(device_), peak_allocated_size_(0), - block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()) {} + block_limit_(1.5 * device_->recommendedMaxWorkingSetSize()), + gc_limit_(0.95 * device_->recommendedMaxWorkingSetSize()) {} -Buffer MetalAllocator::malloc(size_t size) { +Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { // Align up memory if (size > vm_page_size) { size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size); } + // Try the cache MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size); - // Prepare to allocate new memory as needed if (!buf) { - // If we are under very high memory pressure, we don't allocate further - if (device_->currentAllocatedSize() >= block_limit_) { + // If there is too much memory pressure, fail (likely causes a wait). + if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) { return Buffer{nullptr}; } - // If we are still under memory pressure, try cleaning cache - if (buffer_cache_.can_garbage_collect()) { - buffer_cache_.release_cached_buffers(size); + // If we have a lot of memory pressure, check if we can reclaim some memory + // from the cache + if (device_->currentAllocatedSize() + size >= gc_limit_) { + size_t min_bytes_to_free = + size + device_->currentAllocatedSize() - gc_limit_; + buffer_cache_.release_cached_buffers(min_bytes_to_free); } // Allocate new buffer if needed diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 621b78a92..45a58bc13 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -23,11 +23,7 @@ class BufferCache { MTL::Buffer* reuse_from_cache(size_t size); void recycle_to_cache(MTL::Buffer* buf); - size_t release_cached_buffers(size_t min_bytes_to_free); - - bool can_garbage_collect() { - return pool_size_ > 0 && device_->currentAllocatedSize() > gc_limit_; - } + void release_cached_buffers(size_t min_bytes_to_free); private: struct BufferHolder { @@ -49,7 +45,6 @@ class BufferCache { BufferHolder* head_; BufferHolder* tail_; size_t pool_size_; - size_t gc_limit_; }; } // namespace @@ -57,7 +52,7 @@ class BufferCache { class MetalAllocator : public allocator::Allocator { /** Allocator for Metal GPUs. */ public: - virtual Buffer malloc(size_t size) override; + virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual void free(Buffer buffer) override; private: @@ -71,6 +66,7 @@ class MetalAllocator : public allocator::Allocator { // Allocation stats size_t peak_allocated_size_; size_t block_limit_; + size_t gc_limit_; }; MetalAllocator& allocator();