From 29a8b2047b19fb6cb685e6c55fd8c3c75969d13e Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Thu, 21 Dec 2023 16:22:52 -0800 Subject: [PATCH] make things simpler --- mlx/backend/metal/device.cpp | 41 +++++++++++---------- mlx/backend/metal/device.h | 3 -- mlx/backend/metal/metal.cpp | 69 ++++++++++++++++-------------------- mlx/backend/metal/metal.h | 1 + mlx/scheduler.h | 1 + 5 files changed, 55 insertions(+), 60 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index a89cf8d4e..ce3e2792f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -17,10 +17,11 @@ namespace fs = std::filesystem; namespace mlx::core::metal { -static Device metal_device_; - namespace { +// Catch things related to the main-thread static variables +static std::shared_ptr<void> global_memory_pool = new_memory_pool(); + // TODO nicer way to set this or possibly expose as an environment variable static constexpr int MAX_BUFFERS_PER_QUEUE = 12; @@ -112,29 +113,29 @@ MTL::Library* load_library( } // namespace -Device::Device() - : pool_(NS::AutoreleasePool::alloc()->init()), - device_(load_device()), - library_map_({{"mlx", load_library(device_)}}) {} +Device::Device() { + auto pool = new_memory_pool(); + device_ = load_device(); + library_map_ = {{"mlx", load_library(device_)}}; +} Device::~Device() { for (auto& q : queue_map_) { q.second->release(); } - for (auto& k : kernel_map_) { - k.second->release(); - } - for (auto& l : library_map_) { - l.second->release(); - } for (auto& b : buffer_map_) { b.second.second->release(); } for (auto& e : encoder_map_) { e.second->release(); } + for (auto& k : kernel_map_) { + k.second->release(); + } + for (auto& l : library_map_) { + l.second->release(); + } device_->release(); - pool_->release(); } void Device::new_queue(int index) { @@ -243,6 +244,7 @@ void Device::register_library( MTL::ComputePipelineState* Device::get_kernel( const std::string& name, const std::string& lib_name /* = "mlx" 
*/) { + auto pool = new_memory_pool(); // Look for cached kernel if (auto it = kernel_map_.find(name); it != kernel_map_.end()) { return it->second; } @@ -285,18 +287,19 @@ MTL::ComputePipelineState* Device::get_kernel( } Device& device(mlx::core::Device) { - return metal_device_; + static Device metal_device; + return metal_device; } -NS::AutoreleasePool*& Device::g_thread_autorelease_pool() { - static thread_local NS::AutoreleasePool* p = - NS::AutoreleasePool::alloc()->init(); - return p; +std::shared_ptr<void> new_memory_pool() { + auto dtor = [](void* ptr) { + static_cast<NS::AutoreleasePool*>(ptr)->release(); + }; + return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor); } void new_stream(Stream stream) { if (stream.device == mlx::core::Device::gpu) { - device(stream.device).g_thread_autorelease_pool(); device(stream.device).new_queue(stream.index); } } diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 96d537a26..45449a332 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -66,10 +66,7 @@ class Device { MTL::ArgumentEncoder* argument_encoder( const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const; - NS::AutoreleasePool*& g_thread_autorelease_pool(); - private: - NS::AutoreleasePool* pool_; MTL::Device* device_; std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_; std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_; diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 0a825df25..00ccc4541 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -48,44 +48,37 @@ std::function<void()> make_task( std::vector<std::shared_future<void>> deps, std::shared_ptr<std::promise<void>> p, bool retain_graph) { - auto task = [retain_graph, - arr, - deps = std::move(deps), - p = std::move(p)]() mutable { - for (auto& d : deps) { - d.wait(); - } - auto s = arr.primitive().stream(); - auto command_buffer = increment_command_buffer(s); - arr.primitive().eval_gpu(arr.inputs(), arr); - if (p) { - metal::device(s.device).end_encoding(s.index); - scheduler::notify_new_task(s); - command_buffer->addCompletedHandler( - 
[retain_graph, s, arr, p = std::move(p)]( - MTL::CommandBuffer*) mutable { - if (!retain_graph) { - arr.detach(); - } - p->set_value(); - // Signal this thread to clear the pool on a synchroniztion. - scheduler::enqueue(s, [s]() { - metal::device(s.device).g_thread_autorelease_pool()->release(); - metal::device(s.device).g_thread_autorelease_pool() = - NS::AutoreleasePool::alloc()->init(); - }); - scheduler::notify_task_completion(s); - }); - metal::device(s.device).commit_command_buffer(s.index); - } else { - command_buffer->addCompletedHandler( - [retain_graph, s, arr](MTL::CommandBuffer*) mutable { - if (!retain_graph) { - arr.detach(); - } - }); - } - }; + auto task = + [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { + auto pool = new_memory_pool(); + for (auto& d : deps) { + d.wait(); + } + auto s = arr.primitive().stream(); + auto command_buffer = increment_command_buffer(s); + arr.primitive().eval_gpu(arr.inputs(), arr); + if (p) { + metal::device(s.device).end_encoding(s.index); + scheduler::notify_new_task(s); + command_buffer->addCompletedHandler( + [retain_graph, s, arr, p = std::move(p)]( + MTL::CommandBuffer*) mutable { + if (!retain_graph) { + arr.detach(); + } + p->set_value(); + scheduler::notify_task_completion(s); + }); + metal::device(s.device).commit_command_buffer(s.index); + } else { + command_buffer->addCompletedHandler( + [retain_graph, s, arr](MTL::CommandBuffer*) mutable { + if (!retain_graph) { + arr.detach(); + } + }); + } + }; return task; } diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index f1f7ede44..d5ac2d4e6 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -20,6 +20,7 @@ constexpr bool is_available() { } void new_stream(Stream stream); +std::shared_ptr<void> new_memory_pool(); std::function<void()> make_task( array& arr, diff --git a/mlx/scheduler.h b/mlx/scheduler.h index 6506b20ab..d582fba2f 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -35,6 +35,7 @@ struct 
StreamThread { } void thread_fn() { + auto thread_pool = metal::new_memory_pool(); metal::new_stream(stream); while (true) { std::function<void()> task;