// Patch to BufferCache in mlx/backend/common/buffer_cache.h (unified diff, whitespace-collapsed).
// Text ahead of the `diff --git` header is commentary, not part of any hunk.
//
// What the hunks change:
//  * includes: one #include is added ahead of the two existing ones.
//    NOTE(review): the header names are missing from this text (likely stripped during
//    extraction); the added code uses assert, std::min and std::find_if, so confirm that
//    <cassert> and <algorithm> are both reachable.
//  * reuse_from_cache(size): the while-loop becomes a single check. It returns nullptr when
//    the lower_bound iterator is end() or its key is >= min(2 * size, size + 2 * page_size_);
//    otherwise it takes that entry's buffer, subtracts the entry's key (it->first) from
//    pool_size_ instead of calling get_size_, and erases the entry from the list and the map.
//    NOTE(review): `auto it = buffer_pool_.lower_bound(size);` is shown as a removed line
//    while the added lines still use `it`; check this against the upstream patch.
//  * recycle_to_cache(buf): the `if (buf)` guard is replaced by assert(buf); insert({size, bh})
//    becomes emplace(size, bh).
//  * release_cached_buffers(min_bytes_to_free): the `if (tail_->buf)` guard is dropped; the map
//    entry for tail_ is located with equal_range(size) + find_if instead of scanning the whole
//    map; pool_size_ is reduced by total_bytes_freed after the loop.
//    The not-found check compares against its.second, because std::find_if returns the end of
//    the range it searched (its.second), which equals buffer_pool_.end() only when no larger
//    key exists. Comparing against end() would let a miss go on to erase an unrelated entry.
//  * clear(): every holder's buffer is freed and counted unconditionally (the `if (holder->buf)`
//    guard is removed).
diff --git a/mlx/backend/common/buffer_cache.h b/mlx/backend/common/buffer_cache.h index 12ecd80ce..92b20f222 100644 --- a/mlx/backend/common/buffer_cache.h +++ b/mlx/backend/common/buffer_cache.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include @@ -27,37 +28,30 @@ class BufferCache { T* reuse_from_cache(size_t size) { // Find the closest buffer in pool. - T* pbuf = nullptr; - auto it = buffer_pool_.lower_bound(size); - - // Make sure we use most of the available memory. - while (!pbuf && it != buffer_pool_.end() && - it->first < std::min(2 * size, size + 2 * page_size_)) { - // Collect from the cache. - pbuf = it->second->buf; - - // Remove from cache. - remove_from_list(it->second); - it = buffer_pool_.erase(it); + if (it == buffer_pool_.end() || + it->first >= std::min(2 * size, size + 2 * page_size_)) { + return nullptr; } - if (pbuf) { - pool_size_ -= get_size_(pbuf); - } + // Collect from the cache. + T* buf = it->second->buf; + pool_size_ -= it->first; - return pbuf; + // Remove from record. + remove_from_list(it->second); + buffer_pool_.erase(it); + return buf; } void recycle_to_cache(T* buf) { + assert(buf); // Add to cache. - if (buf) { - BufferHolder* bh = new BufferHolder(buf); - add_at_head(bh); - size_t size = get_size_(buf); - pool_size_ += size; - buffer_pool_.insert({size, bh}); - } + BufferHolder* bh = new BufferHolder(buf); + add_at_head(bh); + size_t size = get_size_(buf); + pool_size_ += size; + buffer_pool_.emplace(size, bh); } int release_cached_buffers(size_t min_bytes_to_free) { @@ -68,20 +62,22 @@ class BufferCache { size_t total_bytes_freed = 0; while (tail_ && (total_bytes_freed < min_bytes_to_free)) { - if (tail_->buf) { - total_bytes_freed += get_size_(tail_->buf); - free_(tail_->buf); - tail_->buf = nullptr; - n_release++; - } + // Release buffer. + size_t size = get_size_(tail_->buf); + total_bytes_freed += size; + free_(tail_->buf); + n_release++; + + // Remove from record. 
+ auto its = buffer_pool_.equal_range(size); + auto it = std::find_if(its.first, its.second, [this](const auto& el) { + return el.second == tail_; + }); + assert(it != its.second); + buffer_pool_.erase(it); remove_from_list(tail_); - for (auto it = buffer_pool_.begin(); it != buffer_pool_.end(); ++it) { - if (it->second == tail_) { - buffer_pool_.erase(it); - break; - } - } } + pool_size_ -= total_bytes_freed; return n_release; } @@ -90,10 +86,8 @@ class BufferCache { int clear() { int n_release = 0; for (auto& [size, holder] : buffer_pool_) { - if (holder->buf) { - free_(holder->buf); - n_release++; - } + free_(holder->buf); + n_release++; delete holder; } buffer_pool_.clear();