From d9478d0eb03738e9006b431a31512927ee611357 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Fri, 22 Dec 2023 10:51:49 -0800 Subject: [PATCH] new_memory_pool -> new_scoped_memory_pool --- mlx/backend/metal/device.cpp | 8 ++++---- mlx/backend/metal/metal.cpp | 2 +- mlx/backend/metal/metal.h | 2 +- mlx/scheduler.h | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index ce3e2792f..c48f2908f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -20,7 +20,7 @@ namespace mlx::core::metal { namespace { // Catch things related to the main-thread static variables -static std::shared_ptr global_memory_pool = new_memory_pool(); +static std::shared_ptr global_memory_pool = new_scoped_memory_pool(); // TODO nicer way to set this or possibly expose as an environment variable static constexpr int MAX_BUFFERS_PER_QUEUE = 12; @@ -114,7 +114,7 @@ MTL::Library* load_library( } // namespace Device::Device() { - auto pool = new_memory_pool(); + auto pool = new_scoped_memory_pool(); device_ = load_device(); library_map_ = {{"mlx", load_library(device_)}}; } @@ -244,7 +244,7 @@ void Device::register_library( MTL::ComputePipelineState* Device::get_kernel( const std::string& name, const std::string& lib_name /* = "mlx" */) { - auto pool = new_memory_pool(); + auto pool = new_scoped_memory_pool(); // Look for cached kernel if (auto it = kernel_map_.find(name); it != kernel_map_.end()) { return it->second; @@ -291,7 +291,7 @@ Device& device(mlx::core::Device) { return metal_device; } -std::shared_ptr new_memory_pool() { +std::shared_ptr new_scoped_memory_pool() { auto dtor = [](void* ptr) { static_cast(ptr)->release(); }; diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 00ccc4541..478e57c73 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -50,7 +50,7 @@ std::function make_task( bool retain_graph) { auto task = [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { - auto pool = new_memory_pool(); + auto pool = new_scoped_memory_pool(); for (auto& d : deps) { d.wait(); } diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h index d5ac2d4e6..99f400956 100644 --- a/mlx/backend/metal/metal.h +++ b/mlx/backend/metal/metal.h @@ -20,7 +20,7 @@ constexpr bool is_available() { } void new_stream(Stream stream); -std::shared_ptr new_memory_pool(); +std::shared_ptr new_scoped_memory_pool(); std::function make_task( array& arr, diff --git a/mlx/scheduler.h b/mlx/scheduler.h index d582fba2f..150cc96db 100644 --- a/mlx/scheduler.h +++ b/mlx/scheduler.h @@ -35,7 +35,7 @@ struct StreamThread { } void thread_fn() { - auto thread_pool = metal::new_memory_pool(); + auto thread_pool = metal::new_scoped_memory_pool(); metal::new_stream(stream); while (true) { std::function task;