fences must exit

This commit is contained in:
Awni Hannun 2025-03-07 09:26:07 -08:00
parent c4230747a1
commit 3ad9031a7f
2 changed files with 32 additions and 3 deletions

View File

@ -1,12 +1,39 @@
// Copyright © 2024 Apple Inc. // Copyright © 2024 Apple Inc.
#include "mlx/fence.h" #include <csignal>
#include "mlx/backend/metal/device.h" #include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h" #include "mlx/backend/metal/metal_impl.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h" #include "mlx/scheduler.h"
#include "mlx/utils.h" #include "mlx/utils.h"
namespace mlx::core { namespace mlx::core {
void signal_handler(int signum);
MTL::Buffer* signal_buffer() {
auto init = []() {
signal(SIGTERM, signal_handler);
auto dtor = [](void* buf) {
allocator::free(static_cast<MTL::Buffer*>(buf));
};
auto buf = std::shared_ptr<void>(
allocator::malloc_or_wait(sizeof(uint32_t)).ptr(), dtor);
static_cast<uint32_t*>(
static_cast<MTL::Buffer*>(buf.get())->contents())[0] = 0;
return buf;
};
static std::shared_ptr<void> buf = init();
return static_cast<MTL::Buffer*>(buf.get());
}
void signal_handler(int signum) {
auto buf = signal_buffer();
static_cast<std::atomic_uint*>(buf->contents())[0] = 1;
signal(signum, SIG_DFL);
raise(signum);
}
struct FenceImpl { struct FenceImpl {
FenceImpl() { FenceImpl() {
auto d = metal::device(Device::gpu).mtl_device(); auto d = metal::device(Device::gpu).mtl_device();
@ -94,6 +121,7 @@ void Fence::wait(Stream stream, const array& x) {
auto buf = static_cast<MTL::Buffer*>(f.fence); auto buf = static_cast<MTL::Buffer*>(f.fence);
compute_encoder.set_buffer(buf, 0); compute_encoder.set_buffer(buf, 0);
compute_encoder.set_bytes(f.count, 1); compute_encoder.set_bytes(f.count, 1);
compute_encoder.set_buffer(signal_buffer(), 2);
compute_encoder.dispatch_threads(kernel_dims, kernel_dims); compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
d.get_command_buffer(idx)->addCompletedHandler( d.get_command_buffer(idx)->addCompletedHandler(

View File

@ -39,13 +39,14 @@ constexpr constant metal::thread_scope thread_scope_system =
// single thread kernel to spin wait for timestamp value // single thread kernel to spin wait for timestamp value
[[kernel]] void fence_wait( [[kernel]] void fence_wait(
volatile coherent(system) device uint* timestamp [[buffer(0)]], volatile coherent(system) device uint* timestamp [[buffer(0)]],
constant uint& value [[buffer(1)]]) { constant uint& value [[buffer(1)]],
volatile coherent(system) device uint* sig_handler [[buffer(2)]]) {
while (1) { while (1) {
metal::atomic_thread_fence( metal::atomic_thread_fence(
metal::mem_flags::mem_device, metal::mem_flags::mem_device,
metal::memory_order_seq_cst, metal::memory_order_seq_cst,
metal::thread_scope_system); metal::thread_scope_system);
if (timestamp[0] >= value) { if (timestamp[0] >= value || sig_handler[0] >= 0) {
break; break;
} }
} }