From cd3616a463a594a1679bd9dc823071bed1469c51 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 22 Dec 2023 11:01:26 -0800 Subject: [PATCH] Revisit autorelease memory pools (#260) * make general autorelease pool part of metal device * make things simpler * no metal backend support * new_memory_pool -> new_scoped_memory_pool --- mlx/backend/metal/device.cpp | 41 ++++++++++++++++++---------------- mlx/backend/metal/device.h | 2 -- mlx/backend/metal/metal.cpp | 7 +----- mlx/backend/metal/metal.h | 1 + mlx/backend/no_metal/metal.cpp | 3 +++ mlx/scheduler.h | 1 + 6 files changed, 28 insertions(+), 27 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 6b6158f29..c48f2908f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -17,10 +17,11 @@ namespace fs = std::filesystem; namespace mlx::core::metal { -static Device metal_device_; - namespace { +// Catch things related to the main-thread static variables +static std::shared_ptr global_memory_pool = new_scoped_memory_pool(); + // TODO nicer way to set this or possibly expose as an environment variable static constexpr int MAX_BUFFERS_PER_QUEUE = 12; @@ -112,29 +113,29 @@ MTL::Library* load_library( } // namespace -Device::Device() - : pool_(NS::AutoreleasePool::alloc()->init()), - device_(load_device()), - library_map_({{"mlx", load_library(device_)}}) {} +Device::Device() { + auto pool = new_scoped_memory_pool(); + device_ = load_device(); + library_map_ = {{"mlx", load_library(device_)}}; +} Device::~Device() { for (auto& q : queue_map_) { q.second->release(); } - for (auto& k : kernel_map_) { - k.second->release(); - } - for (auto& l : library_map_) { - l.second->release(); - } for (auto& b : buffer_map_) { b.second.second->release(); } for (auto& e : encoder_map_) { e.second->release(); } + for (auto& k : kernel_map_) { + k.second->release(); + } + for (auto& l : library_map_) { + l.second->release(); + } device_->release(); - pool_->release(); } void Device::new_queue(int index) { @@ -243,6 +244,7 @@ void Device::register_library( MTL::ComputePipelineState* Device::get_kernel( const std::string& name, const std::string& lib_name /* = "mlx" */) { + auto pool = new_scoped_memory_pool(); // Look for cached kernel if (auto it = kernel_map_.find(name); it != kernel_map_.end()) { return it->second; @@ -285,17 +287,18 @@ MTL::ComputePipelineState* Device::get_kernel( } Device& device(mlx::core::Device) { - return metal_device_; + static Device metal_device; + return metal_device; } -NS::AutoreleasePool*& thread_autorelease_pool() { - static thread_local NS::AutoreleasePool* p = - NS::AutoreleasePool::alloc()->init(); - return p; +std::shared_ptr new_scoped_memory_pool() { + auto dtor = [](void* ptr) { + static_cast(ptr)->release(); + }; + return std::shared_ptr(NS::AutoreleasePool::alloc()->init(), dtor); } void new_stream(Stream stream) { - thread_autorelease_pool(); if (stream.device == mlx::core::Device::gpu) { device(stream.device).new_queue(stream.index); } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 62675d430..45449a332 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -67,7 +67,6 @@ class Device { const std::vector& arg_descs) const; private: - NS::AutoreleasePool* pool_; MTL::Device* device_; std::unordered_map queue_map_; std::unordered_map> buffer_map_; @@ -78,6 +77,5 @@ class Device { }; Device& device(mlx::core::Device); -NS::AutoreleasePool*& thread_autorelease_pool(); } // namespace mlx::core::metal diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index f63ad55a3..478e57c73 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -50,6 +50,7 @@ std::function make_task( bool retain_graph) { auto task = [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { + auto pool = new_scoped_memory_pool(); for (auto& d : deps) { d.wait(); } @@ -66,12 +67,6 @@ std::function make_task( arr.detach(); } p->set_value(); - // Signal this thread to clear the pool on a synchroniztion. - scheduler::enqueue(s, []() { - thread_autorelease_pool()->release(); - thread_autorelease_pool() = - NS::AutoreleasePool::alloc()->init(); - }); scheduler::notify_task_completion(s); }); metal::device(s.device).commit_command_buffer(s.index); diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index f1f7ede44..99f400956 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -20,6 +20,7 @@ constexpr bool is_available() { } void new_stream(Stream stream); +std::shared_ptr new_scoped_memory_pool(); std::function make_task( array& arr, diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp index accfc4c8a..b3a7dc41c 100644 --- a/mlx/backend/no_metal/metal.cpp +++ b/mlx/backend/no_metal/metal.cpp @@ -7,6 +7,9 @@ namespace mlx::core::metal { void new_stream(Stream) {} +std::shared_ptr new_memory_pool() { + return nullptr; +} std::function make_task( array& arr, diff --git a/mlx/scheduler.h b/mlx/scheduler.h index 6506b20ab..150cc96db 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -35,6 +35,7 @@ struct StreamThread { } void thread_fn() { + auto thread_pool = metal::new_scoped_memory_pool(); metal::new_stream(stream); while (true) { std::function task;