mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Revisit autorelease memory pools (#260)
* make general autorelease pool part of metal device * make things simpler * no metal backend support * new_memory_pool -> new_scoped_memory_pool
This commit is contained in:
		| @@ -17,10 +17,11 @@ namespace fs = std::filesystem; | ||||
|  | ||||
| namespace mlx::core::metal { | ||||
|  | ||||
| static Device metal_device_; | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| // Catch things related to the main-thread static variables | ||||
| static std::shared_ptr<void> global_memory_pool = new_scoped_memory_pool(); | ||||
|  | ||||
| // TODO nicer way to set this or possibly expose as an environment variable | ||||
| static constexpr int MAX_BUFFERS_PER_QUEUE = 12; | ||||
|  | ||||
| @@ -112,29 +113,29 @@ MTL::Library* load_library( | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| Device::Device() | ||||
|     : pool_(NS::AutoreleasePool::alloc()->init()), | ||||
|       device_(load_device()), | ||||
|       library_map_({{"mlx", load_library(device_)}}) {} | ||||
| Device::Device() { | ||||
|   auto pool = new_scoped_memory_pool(); | ||||
|   device_ = load_device(); | ||||
|   library_map_ = {{"mlx", load_library(device_)}}; | ||||
| } | ||||
|  | ||||
| Device::~Device() { | ||||
|   for (auto& q : queue_map_) { | ||||
|     q.second->release(); | ||||
|   } | ||||
|   for (auto& k : kernel_map_) { | ||||
|     k.second->release(); | ||||
|   } | ||||
|   for (auto& l : library_map_) { | ||||
|     l.second->release(); | ||||
|   } | ||||
|   for (auto& b : buffer_map_) { | ||||
|     b.second.second->release(); | ||||
|   } | ||||
|   for (auto& e : encoder_map_) { | ||||
|     e.second->release(); | ||||
|   } | ||||
|   for (auto& k : kernel_map_) { | ||||
|     k.second->release(); | ||||
|   } | ||||
|   for (auto& l : library_map_) { | ||||
|     l.second->release(); | ||||
|   } | ||||
|   device_->release(); | ||||
|   pool_->release(); | ||||
| } | ||||
|  | ||||
| void Device::new_queue(int index) { | ||||
| @@ -243,6 +244,7 @@ void Device::register_library( | ||||
| MTL::ComputePipelineState* Device::get_kernel( | ||||
|     const std::string& name, | ||||
|     const std::string& lib_name /* = "mlx" */) { | ||||
|   auto pool = new_scoped_memory_pool(); | ||||
|   // Look for cached kernel | ||||
|   if (auto it = kernel_map_.find(name); it != kernel_map_.end()) { | ||||
|     return it->second; | ||||
| @@ -285,17 +287,18 @@ MTL::ComputePipelineState* Device::get_kernel( | ||||
| } | ||||
|  | ||||
| Device& device(mlx::core::Device) { | ||||
|   return metal_device_; | ||||
|   static Device metal_device; | ||||
|   return metal_device; | ||||
| } | ||||
|  | ||||
| NS::AutoreleasePool*& thread_autorelease_pool() { | ||||
|   static thread_local NS::AutoreleasePool* p = | ||||
|       NS::AutoreleasePool::alloc()->init(); | ||||
|   return p; | ||||
| std::shared_ptr<void> new_scoped_memory_pool() { | ||||
|   auto dtor = [](void* ptr) { | ||||
|     static_cast<NS::AutoreleasePool*>(ptr)->release(); | ||||
|   }; | ||||
|   return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor); | ||||
| } | ||||
|  | ||||
| void new_stream(Stream stream) { | ||||
|   thread_autorelease_pool(); | ||||
|   if (stream.device == mlx::core::Device::gpu) { | ||||
|     device(stream.device).new_queue(stream.index); | ||||
|   } | ||||
|   | ||||
| @@ -67,7 +67,6 @@ class Device { | ||||
|       const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const; | ||||
|  | ||||
|  private: | ||||
|   NS::AutoreleasePool* pool_; | ||||
|   MTL::Device* device_; | ||||
|   std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_; | ||||
|   std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_; | ||||
| @@ -78,6 +77,5 @@ class Device { | ||||
| }; | ||||
|  | ||||
| Device& device(mlx::core::Device); | ||||
| NS::AutoreleasePool*& thread_autorelease_pool(); | ||||
|  | ||||
| } // namespace mlx::core::metal | ||||
|   | ||||
| @@ -50,6 +50,7 @@ std::function<void()> make_task( | ||||
|     bool retain_graph) { | ||||
|   auto task = | ||||
|       [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable { | ||||
|         auto pool = new_scoped_memory_pool(); | ||||
|         for (auto& d : deps) { | ||||
|           d.wait(); | ||||
|         } | ||||
| @@ -66,12 +67,6 @@ std::function<void()> make_task( | ||||
|                   arr.detach(); | ||||
|                 } | ||||
|                 p->set_value(); | ||||
|                 // Signal this thread to clear the pool on a synchronization. | ||||
|                 scheduler::enqueue(s, []() { | ||||
|                   thread_autorelease_pool()->release(); | ||||
|                   thread_autorelease_pool() = | ||||
|                       NS::AutoreleasePool::alloc()->init(); | ||||
|                 }); | ||||
|                 scheduler::notify_task_completion(s); | ||||
|               }); | ||||
|           metal::device(s.device).commit_command_buffer(s.index); | ||||
|   | ||||
| @@ -20,6 +20,7 @@ constexpr bool is_available() { | ||||
| } | ||||
|  | ||||
| void new_stream(Stream stream); | ||||
| std::shared_ptr<void> new_scoped_memory_pool(); | ||||
|  | ||||
| std::function<void()> make_task( | ||||
|     array& arr, | ||||
|   | ||||
| @@ -7,6 +7,9 @@ | ||||
| namespace mlx::core::metal { | ||||
|  | ||||
| void new_stream(Stream) {} | ||||
// Stub counterpart of `new_scoped_memory_pool`: there is no pool object to
// manage here, so callers receive an empty handle whose destruction is a
// no-op.
//
// The function must carry the name `new_scoped_memory_pool`, because that is
// the name that is declared and that callers invoke; under the earlier
// spelling `new_memory_pool` the declared function had no definition.
std::shared_ptr<void> new_scoped_memory_pool() {
  return nullptr;
}
|  | ||||
| std::function<void()> make_task( | ||||
|     array& arr, | ||||
|   | ||||
| @@ -35,6 +35,7 @@ struct StreamThread { | ||||
|   } | ||||
|  | ||||
|   void thread_fn() { | ||||
|     auto thread_pool = metal::new_scoped_memory_pool(); | ||||
|     metal::new_stream(stream); | ||||
|     while (true) { | ||||
|       std::function<void()> task; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Ronan Collobert
					Ronan Collobert