Move the per-thread Metal autorelease pool accessor onto metal::Device.

* device.cpp / device.h: the free function thread_autorelease_pool() becomes
  the member Device::g_thread_autorelease_pool(). It still returns a reference
  to a function-local `static thread_local NS::AutoreleasePool*`.
* device.cpp: new_stream() now touches the pool only when the stream's device
  is the GPU, instead of unconditionally.
* metal.cpp: the closure that make_task() hands to scheduler::enqueue now
  captures the stream `s` so it can reach the pool through
  metal::device(s.device); the task lambda is otherwise only re-indented.

diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 6b6158f29..a89cf8d4e 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -288,15 +288,15 @@ Device& device(mlx::core::Device) {
   return metal_device_;
 }
 
-NS::AutoreleasePool*& thread_autorelease_pool() {
+NS::AutoreleasePool*& Device::g_thread_autorelease_pool() {
   static thread_local NS::AutoreleasePool* p =
       NS::AutoreleasePool::alloc()->init();
   return p;
 }
 
 void new_stream(Stream stream) {
-  thread_autorelease_pool();
   if (stream.device == mlx::core::Device::gpu) {
+    device(stream.device).g_thread_autorelease_pool();
     device(stream.device).new_queue(stream.index);
   }
 }
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index 62675d430..96d537a26 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -66,6 +66,8 @@ class Device {
   MTL::ArgumentEncoder* argument_encoder(
       const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
 
+  NS::AutoreleasePool*& g_thread_autorelease_pool();
+
  private:
   NS::AutoreleasePool* pool_;
   MTL::Device* device_;
@@ -78,6 +80,5 @@ class Device {
 };
 
 Device& device(mlx::core::Device);
-NS::AutoreleasePool*& thread_autorelease_pool();
 
 } // namespace mlx::core::metal
diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp
index f63ad55a3..0a825df25 100644
--- a/mlx/backend/metal/metal.cpp
+++ b/mlx/backend/metal/metal.cpp
@@ -48,42 +48,44 @@ std::function<void()> make_task(
     std::vector<std::shared_future<void>> deps,
     std::shared_ptr<std::promise<void>> p,
     bool retain_graph) {
-  auto task =
-      [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
-        for (auto& d : deps) {
-          d.wait();
-        }
-        auto s = arr.primitive().stream();
-        auto command_buffer = increment_command_buffer(s);
-        arr.primitive().eval_gpu(arr.inputs(), arr);
-        if (p) {
-          metal::device(s.device).end_encoding(s.index);
-          scheduler::notify_new_task(s);
-          command_buffer->addCompletedHandler(
-              [retain_graph, s, arr, p = std::move(p)](
-                  MTL::CommandBuffer*) mutable {
-                if (!retain_graph) {
-                  arr.detach();
-                }
-                p->set_value();
-                // Signal this thread to clear the pool on a synchroniztion.
-                scheduler::enqueue(s, []() {
-                  thread_autorelease_pool()->release();
-                  thread_autorelease_pool() =
-                      NS::AutoreleasePool::alloc()->init();
-                });
-                scheduler::notify_task_completion(s);
-              });
-          metal::device(s.device).commit_command_buffer(s.index);
-        } else {
-          command_buffer->addCompletedHandler(
-              [retain_graph, s, arr](MTL::CommandBuffer*) mutable {
-                if (!retain_graph) {
-                  arr.detach();
-                }
-              });
-        }
-      };
+  auto task = [retain_graph,
+               arr,
+               deps = std::move(deps),
+               p = std::move(p)]() mutable {
+    for (auto& d : deps) {
+      d.wait();
+    }
+    auto s = arr.primitive().stream();
+    auto command_buffer = increment_command_buffer(s);
+    arr.primitive().eval_gpu(arr.inputs(), arr);
+    if (p) {
+      metal::device(s.device).end_encoding(s.index);
+      scheduler::notify_new_task(s);
+      command_buffer->addCompletedHandler(
+          [retain_graph, s, arr, p = std::move(p)](
+              MTL::CommandBuffer*) mutable {
+            if (!retain_graph) {
+              arr.detach();
+            }
+            p->set_value();
+            // Signal this thread to clear the pool on a synchronization.
+            scheduler::enqueue(s, [s]() {
+              metal::device(s.device).g_thread_autorelease_pool()->release();
+              metal::device(s.device).g_thread_autorelease_pool() =
+                  NS::AutoreleasePool::alloc()->init();
+            });
+            scheduler::notify_task_completion(s);
+          });
+      metal::device(s.device).commit_command_buffer(s.index);
+    } else {
+      command_buffer->addCompletedHandler(
+          [retain_graph, s, arr](MTL::CommandBuffer*) mutable {
+            if (!retain_graph) {
+              arr.detach();
+            }
+          });
+    }
+  };
   return task;
 }
 