From ba12e4999ab379ef99b34f2e3e9d80f355b5e7b5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 3 Mar 2025 06:50:57 -0800 Subject: [PATCH] Use a heap for small sizes (#1911) * use a heap for small sizes * check if VM --- mlx/backend/metal/allocator.cpp | 49 ++++++++++++++++++++++++++------- mlx/backend/metal/allocator.h | 9 ++++++ mlx/backend/metal/device.cpp | 4 ++- mlx/backend/metal/metal.h | 2 +- mlx/backend/no_metal/metal.cpp | 2 +- 5 files changed, 53 insertions(+), 13 deletions(-) diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 7d34fc815..f2c95be20 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -10,6 +10,9 @@ namespace mlx::core { +constexpr size_t resource_options = + MTL::ResourceStorageModeShared | MTL::ResourceHazardTrackingModeUntracked; + namespace allocator { Allocator& allocator() { @@ -150,15 +153,34 @@ MetalAllocator::MetalAllocator() : device_(device(mlx::core::Device::gpu).mtl_device()), residency_set_(device_), buffer_cache_(device_) { - auto memsize = std::get(device_info()["memory_size"]); + auto pool = metal::new_scoped_memory_pool(); + auto memsize = std::get(device_info().at("memory_size")); auto max_rec_size = - std::get(device_info()["max_recommended_working_set_size"]); - resource_limit_ = std::get(device_info()["resource_limit"]); + std::get(device_info().at("max_recommended_working_set_size")); + resource_limit_ = std::get(device_info().at("resource_limit")); block_limit_ = std::min(1.5 * max_rec_size, 0.95 * memsize); gc_limit_ = std::min(static_cast(0.95 * max_rec_size), block_limit_); max_pool_size_ = block_limit_; device(mlx::core::Device::gpu) .set_residency_set(residency_set_.mtl_residency_set()); + bool is_vm = std::get(device_info().at("device_name")) == + "Apple Paravirtual device"; + if (is_vm) { + return; + } + auto heap_desc = MTL::HeapDescriptor::alloc()->init(); + heap_desc->setResourceOptions(resource_options); + heap_desc->setSize(heap_size_); + heap_ = device_->newHeap(heap_desc); + heap_desc->release(); + residency_set_.insert(heap_); +} + +MetalAllocator::~MetalAllocator() { + auto pool = metal::new_scoped_memory_pool(); + if (heap_) { + heap_->release(); + } } size_t MetalAllocator::set_cache_limit(size_t limit) { @@ -226,8 +248,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { } // Allocate new buffer if needed - size_t res_opt = MTL::ResourceStorageModeShared; - res_opt |= MTL::ResourceHazardTrackingModeUntracked; if (num_resources_ >= resource_limit_) { std::ostringstream msg; msg << "[metal::malloc] Resource limit (" << resource_limit_ @@ -235,7 +255,12 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { throw std::runtime_error(msg.str()); } lk.unlock(); - buf = device_->newBuffer(size, res_opt); + if (size < small_size_ && heap_) { + buf = heap_->newBuffer(size, resource_options); + } + if (!buf) { + buf = device_->newBuffer(size, resource_options); + } lk.lock(); if (buf) { num_resources_++; @@ -246,13 +271,15 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { peak_memory_ = std::max(peak_memory_, active_memory_); // Maintain the cache below the requested limit - if (get_cache_memory() >= max_pool_size_) { + if (get_cache_memory() > max_pool_size_) { auto pool = metal::new_scoped_memory_pool(); num_resources_ -= buffer_cache_.release_cached_buffers( get_cache_memory() - max_pool_size_); } - residency_set_.insert(buf); + if (!buf->heap()) { + residency_set_.insert(buf); + } return Buffer{static_cast(buf)}; } @@ -269,7 +296,9 @@ void MetalAllocator::free(Buffer buffer) { return; } std::unique_lock lk(mutex_); - residency_set_.erase(buf); + if (!buf->heap()) { + residency_set_.erase(buf); + } active_memory_ -= buf->length(); if (get_cache_memory() < max_pool_size_) { buffer_cache_.recycle_to_cache(buf); @@ -301,7 +330,7 @@ size_t set_memory_limit(size_t limit, bool relaxed /* = true */) { } size_t set_wired_limit(size_t limit) { if (limit > - std::get(device_info()["max_recommended_working_set_size"])) { + std::get(device_info().at("max_recommended_working_set_size"))) { throw std::invalid_argument( "[metal::set_wired_limit] Setting a wired limit larger than " "the maximum working set size is not allowed."); diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h index 4e662ab56..df301f55e 100644 --- a/mlx/backend/metal/allocator.h +++ b/mlx/backend/metal/allocator.h @@ -43,6 +43,7 @@ class BufferCache { void remove_from_list(BufferHolder* to_remove); MTL::Device* device_; + MTL::Heap* heap_{nullptr}; std::multimap buffer_pool_; BufferHolder* head_; @@ -78,7 +79,15 @@ class MetalAllocator : public allocator::Allocator { private: MTL::Device* device_; + + // The size of allocations which go on the heap until it is full. This size + // is chosen because it is the actual minimum size of a buffer allocated from + // the heap, a heap can have at most heap.size() / 256 buffers. + static constexpr int small_size_ = 256; + static constexpr int heap_size_ = 1 << 20; + MTL::Heap* heap_; MetalAllocator(); + ~MetalAllocator(); friend MetalAllocator& allocator(); // Caching allocator diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 7a82e2fb3..06681c458 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -692,12 +692,13 @@ void new_stream(Stream stream) { } } -std::unordered_map> +const std::unordered_map>& device_info() { auto init_device_info = []() -> std::unordered_map> { auto pool = new_scoped_memory_pool(); auto raw_device = device(default_device()).mtl_device(); + auto name = std::string(raw_device->name()->utf8String()); auto arch = std::string(raw_device->architecture()->name()->utf8String()); size_t memsize = 0; @@ -711,6 +712,7 @@ device_info() { } return { + {"device_name", name}, {"architecture", arch}, {"max_buffer_length", raw_device->maxBufferLength()}, {"max_recommended_working_set_size", diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index e5cb65afd..d5c64f79d 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -82,7 +82,7 @@ void start_capture(std::string path = ""); void stop_capture(); /** Get information about the GPU and system settings. */ -std::unordered_map> +const std::unordered_map>& device_info(); } // namespace mlx::core::metal diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index d23f8d33a..9ae9800a2 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -54,7 +54,7 @@ void start_capture(std::string) {} void stop_capture() {} void clear_cache() {} -std::unordered_map> +const std::unordered_map>& device_info() { throw std::runtime_error( "[metal::device_info] Cannot get device info without metal backend");