diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 0eec44bfa..9911acef1 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -33,8 +33,11 @@ namespace metal { namespace { -BufferCache::BufferCache(MTL::Device* device) - : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {} +BufferCache::BufferCache(ResidencySet& residency_set) + : head_(nullptr), + tail_(nullptr), + pool_size_(0), + residency_set_(residency_set) {} BufferCache::~BufferCache() { auto pool = metal::new_scoped_memory_pool(); @@ -102,6 +105,9 @@ int BufferCache::release_cached_buffers(size_t min_bytes_to_free) { while (tail_ && (total_bytes_freed < min_bytes_to_free)) { if (tail_->buf) { total_bytes_freed += tail_->buf->length(); + if (!tail_->buf->heap()) { + residency_set_.erase(tail_->buf); + } tail_->buf->release(); tail_->buf = nullptr; n_release++; @@ -156,7 +162,7 @@ void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) { MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), residency_set_(device_), - buffer_cache_(device_) { + buffer_cache_(residency_set_) { auto pool = metal::new_scoped_memory_pool(); auto memsize = std::get(device_info().at("memory_size")); auto max_rec_size = @@ -298,14 +304,14 @@ void MetalAllocator::free(Buffer buffer) { return; } std::unique_lock lk(mutex_); - if (!buf->heap()) { - residency_set_.erase(buf); - } active_memory_ -= buf->length(); if (get_cache_memory() < max_pool_size_) { buffer_cache_.recycle_to_cache(buf); } else { num_resources_--; + if (!buf->heap()) { + residency_set_.erase(buf); + } lk.unlock(); auto pool = metal::new_scoped_memory_pool(); buf->release(); diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 8b77ff6c1..227b09e91 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -18,7 +18,7 @@ namespace { class BufferCache { public: - BufferCache(MTL::Device* device); + BufferCache(ResidencySet& residency_set); ~BufferCache(); MTL::Buffer* reuse_from_cache(size_t size); @@ -42,13 +42,11 @@ class BufferCache { void add_at_head(BufferHolder* to_add); void remove_from_list(BufferHolder* to_remove); - MTL::Device* device_; - MTL::Heap* heap_{nullptr}; - std::multimap buffer_pool_; BufferHolder* head_; BufferHolder* tail_; size_t pool_size_; + ResidencySet& residency_set_; }; } // namespace