diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index a8db43bb1..5d95cae4a 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -23,41 +23,152 @@ void* Buffer::raw_ptr() {
 
 namespace metal {
 
+namespace {
+
+BufferCache::BufferCache(MTL::Device* device)
+    : device_(device), head_(nullptr), tail_(nullptr), pool_size_(0) {}
+
+BufferCache::~BufferCache() {
+  clear();
+}
+
+void BufferCache::clear() {
+  std::lock_guard lk(cache_mutex_);
+  for (auto& [size, holder] : buffer_pool_) {
+    if (holder->buf)
+      holder->buf->release();
+    delete holder;
+  }
+  buffer_pool_.clear();
+  pool_size_ = 0;
+  head_ = nullptr;
+  tail_ = nullptr;
+}
+
+MTL::Buffer* BufferCache::reuse_from_cache(size_t size) {
+  std::lock_guard lk(cache_mutex_);
+
+  // Find the closest buffer in the pool
+  MTL::Buffer* pbuf = nullptr;
+
+  // Only reuse a cached buffer if more than 50% of it would be used
+  if (auto it = buffer_pool_.lower_bound(size);
+      it != buffer_pool_.end() && it->first < 2 * size) {
+    // Collect from the cache
+    pbuf = it->second->buf;
+    // Remove from cache
+    remove_from_list(it->second);
+    delete it->second;
+    it = buffer_pool_.erase(it);
+  }
+
+  if (pbuf) {
+    pool_size_ -= pbuf->length();
+  }
+
+  return pbuf;
+}
+
+void BufferCache::recycle_to_cache(MTL::Buffer* buf) {
+  std::lock_guard lk(cache_mutex_);
+
+  // Add to cache
+  if (buf) {
+    BufferHolder* bh = new BufferHolder(buf);
+    add_at_head(bh);
+    pool_size_ += buf->length();
+    buffer_pool_.insert({buf->length(), bh});
+  }
+}
+
+void BufferCache::release_cached_buffers(size_t min_bytes_to_free) {
+  if (min_bytes_to_free >= 0.9 * pool_size_) {
+    clear();
+  } else {
+    std::lock_guard lk(cache_mutex_);
+    size_t total_bytes_freed = 0;
+    while (tail_ && (total_bytes_freed < min_bytes_to_free)) {
+      if (tail_->buf) {
+        total_bytes_freed += tail_->buf->length();
+        tail_->buf->release();
+        tail_->buf = nullptr;
+      }
+      remove_from_list(tail_);
+    }
+    pool_size_ -= total_bytes_freed;
+  }
+}
+
+void BufferCache::add_at_head(BufferCache::BufferHolder* to_add) {
+  if (!to_add)
+    return;
+
+  if (!head_) {
+    head_ = to_add;
+    tail_ = to_add;
+  } else {
+    head_->prev = to_add;
+    to_add->next = head_;
+    head_ = to_add;
+  }
+}
+
+void BufferCache::remove_from_list(BufferCache::BufferHolder* to_remove) {
+  if (!to_remove)
+    return;
+
+  // If in the middle
+  if (to_remove->prev && to_remove->next) {
+    to_remove->prev->next = to_remove->next;
+    to_remove->next->prev = to_remove->prev;
+  } else if (to_remove->prev && to_remove == tail_) { // If tail
+    tail_ = to_remove->prev;
+    tail_->next = nullptr;
+  } else if (to_remove == head_ && to_remove->next) { // If head
+    head_ = to_remove->next;
+    head_->prev = nullptr;
+  } else if (to_remove == head_ && to_remove == tail_) { // If only element
+    head_ = nullptr;
+    tail_ = nullptr;
+  }
+
+  to_remove->prev = nullptr;
+  to_remove->next = nullptr;
+}
+
+} // namespace
+
 MetalAllocator::MetalAllocator()
     : device_(device(mlx::core::Device::gpu).mtl_device()),
+      buffer_cache_(device_),
       peak_allocated_size_(0),
       block_limit_(device_->recommendedMaxWorkingSetSize()) {}
 
 Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   // Align up memory
-  /// if (size > vm_page_size) {
-  //  size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
-  //}
+  if (size > vm_page_size) {
+    size = vm_page_size * ((size + vm_page_size - 1) / vm_page_size);
+  }
 
-  // MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
+  MTL::Buffer* buf = buffer_cache_.reuse_from_cache(size);
 
   // Prepare to allocate new memory as needed
-  // if (!buf) {
-  // If we have memory pressure, first check if we can reclaim some memory
-  // from the cache
-  // if (auto new_size = device_->currentAllocatedSize() + size; new_size >=
-  // block_limit_) {
-  //  buffer_cache_.clear();
-  //  buffer_cache_.release_cached_buffers(
-  //      std::max(new_size - block_limit_, size));
-  // }
+  if (!buf) {
+    // If we have memory pressure, first try to reclaim memory from the cache
+    if (auto new_size = device_->currentAllocatedSize() + size;
+        new_size >= block_limit_) {
+      buffer_cache_.clear();
+    }
 
-  // If there is still too much memory pressure, fail (likely causes a wait).
-  // size + allocated (to avoid going over the limit)
-  if (!allow_swap && device_->currentAllocatedSize() + size >= block_limit_) {
-    return Buffer{nullptr};
+    // If there is still too much memory pressure, fail (likely causes a wait).
+    if (device_->currentAllocatedSize() >= block_limit_) {
+      return Buffer{nullptr};
+    }
+
+    // Allocate new buffer if needed
+    size_t res_opt = MTL::ResourceStorageModeShared;
+    res_opt |= MTL::ResourceHazardTrackingModeTracked;
+    buf = device_->newBuffer(size, res_opt);
   }
-  // }
-
-  // Allocate new buffer if needed
-  size_t res_opt = MTL::ResourceStorageModeShared;
-  res_opt |= MTL::ResourceHazardTrackingModeTracked;
-  auto buf = device_->newBuffer(size, res_opt);
 
   peak_allocated_size_ =
       std::max(peak_allocated_size_, device_->currentAllocatedSize());
@@ -66,7 +177,8 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
 }
 
 void MetalAllocator::free(Buffer buffer) {
-  static_cast<MTL::Buffer*>(buffer.ptr())->release();
+  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
+  buffer_cache_.recycle_to_cache(buf);
 }
 
 MetalAllocator& allocator() {
diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h
index a3a790b41..24b66653e 100644
--- a/mlx/backend/metal/allocator.h
+++ b/mlx/backend/metal/allocator.h
@@ -13,6 +13,42 @@ namespace mlx::core::metal {
 
 using allocator::Buffer;
 
+namespace {
+
+class BufferCache {
+ public:
+  BufferCache(MTL::Device* device);
+  ~BufferCache();
+  void clear();
+
+  MTL::Buffer* reuse_from_cache(size_t size);
+  void recycle_to_cache(MTL::Buffer* buf);
+  void release_cached_buffers(size_t min_bytes_to_free);
+
+ private:
+  struct BufferHolder {
+   public:
+    BufferHolder(MTL::Buffer* buf_) : buf(buf_), prev(nullptr), next(nullptr) {}
+
+    BufferHolder* prev;
+    BufferHolder* next;
+    MTL::Buffer* buf;
+  };
+
+  void add_at_head(BufferHolder* to_add);
+  void remove_from_list(BufferHolder* to_remove);
+
+  MTL::Device* device_;
+  std::mutex cache_mutex_;
+
+  std::multimap<size_t, BufferHolder*> buffer_pool_;
+  BufferHolder* head_;
+  BufferHolder* tail_;
+  size_t pool_size_;
+};
+
+} // namespace
+
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
@@ -24,6 +60,9 @@ class MetalAllocator : public allocator::Allocator {
   MetalAllocator();
   friend MetalAllocator& allocator();
 
+  // Caching allocator
+  BufferCache buffer_cache_;
+
  // Allocation stats
   size_t peak_allocated_size_;
   size_t block_limit_;
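
For context, below is a minimal standalone sketch (not part of the patch; the names pool and reuse_from_pool and the example sizes are illustrative) of the best-fit reuse policy that BufferCache::reuse_from_cache implements with std::multimap::lower_bound: a cached buffer is only handed back when it is smaller than twice the requested size, i.e. when the request would use more than half of it. The real implementation additionally holds cache_mutex_ and maintains the LRU list, which this sketch omits.

    #include <cstddef>
    #include <iostream>
    #include <map>

    // Toy stand-in for buffer_pool_: buffer size -> cached buffer id.
    std::multimap<size_t, int> pool = {{256, 0}, {1024, 1}, {4096, 2}};

    // Return the id of a cached buffer to reuse, or -1 on a cache miss.
    int reuse_from_pool(size_t size) {
      auto it = pool.lower_bound(size); // smallest cached buffer >= size
      if (it != pool.end() && it->first < 2 * size) { // request uses > 50% of it
        int id = it->second;
        pool.erase(it); // drop it from the cache, like remove_from_list + erase
        return id;
      }
      return -1; // miss: the caller would fall through to a fresh allocation
    }

    int main() {
      std::cout << reuse_from_pool(1000) << "\n"; // 1: 1024 < 2 * 1000, reuse
      std::cout << reuse_from_pool(300) << "\n";  // -1: 1024 is gone, 4096 too wasteful
      std::cout << reuse_from_pool(4096) << "\n"; // 2: exact-size hit
    }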