mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	fences must exit
This commit is contained in:
		| @@ -1,12 +1,39 @@ | ||||
| // Copyright © 2024 Apple Inc. | ||||
| #include "mlx/fence.h" | ||||
| #include <csignal> | ||||
|  | ||||
| #include "mlx/backend/metal/device.h" | ||||
| #include "mlx/backend/metal/metal_impl.h" | ||||
| #include "mlx/fence.h" | ||||
| #include "mlx/scheduler.h" | ||||
| #include "mlx/utils.h" | ||||
|  | ||||
| namespace mlx::core { | ||||
|  | ||||
| void signal_handler(int signum); | ||||
|  | ||||
| MTL::Buffer* signal_buffer() { | ||||
|   auto init = []() { | ||||
|     signal(SIGTERM, signal_handler); | ||||
|     auto dtor = [](void* buf) { | ||||
|       allocator::free(static_cast<MTL::Buffer*>(buf)); | ||||
|     }; | ||||
|     auto buf = std::shared_ptr<void>( | ||||
|         allocator::malloc_or_wait(sizeof(uint32_t)).ptr(), dtor); | ||||
|     static_cast<uint32_t*>( | ||||
|         static_cast<MTL::Buffer*>(buf.get())->contents())[0] = 0; | ||||
|     return buf; | ||||
|   }; | ||||
|   static std::shared_ptr<void> buf = init(); | ||||
|   return static_cast<MTL::Buffer*>(buf.get()); | ||||
| } | ||||
|  | ||||
| void signal_handler(int signum) { | ||||
|   auto buf = signal_buffer(); | ||||
|   static_cast<std::atomic_uint*>(buf->contents())[0] = 1; | ||||
|   signal(signum, SIG_DFL); | ||||
|   raise(signum); | ||||
| } | ||||
|  | ||||
| struct FenceImpl { | ||||
|   FenceImpl() { | ||||
|     auto d = metal::device(Device::gpu).mtl_device(); | ||||
| @@ -94,6 +121,7 @@ void Fence::wait(Stream stream, const array& x) { | ||||
|   auto buf = static_cast<MTL::Buffer*>(f.fence); | ||||
|   compute_encoder.set_buffer(buf, 0); | ||||
|   compute_encoder.set_bytes(f.count, 1); | ||||
|   compute_encoder.set_buffer(signal_buffer(), 2); | ||||
|   compute_encoder.dispatch_threads(kernel_dims, kernel_dims); | ||||
|  | ||||
|   d.get_command_buffer(idx)->addCompletedHandler( | ||||
|   | ||||
| @@ -39,13 +39,14 @@ constexpr constant metal::thread_scope thread_scope_system = | ||||
| // single thread kernel to spin wait for timestamp value | ||||
| [[kernel]] void fence_wait( | ||||
|     volatile coherent(system) device uint* timestamp [[buffer(0)]], | ||||
|     constant uint& value [[buffer(1)]]) { | ||||
|     constant uint& value [[buffer(1)]], | ||||
|     volatile coherent(system) device uint* sig_handler [[buffer(2)]]) { | ||||
|   while (1) { | ||||
|     metal::atomic_thread_fence( | ||||
|         metal::mem_flags::mem_device, | ||||
|         metal::memory_order_seq_cst, | ||||
|         metal::thread_scope_system); | ||||
|     if (timestamp[0] >= value) { | ||||
|     if (timestamp[0] >= value || sig_handler[0] >= 0) { | ||||
|       break; | ||||
|     } | ||||
|   } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun