mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fences must exit
This commit is contained in:
parent
c4230747a1
commit
3ad9031a7f
@ -1,12 +1,39 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
#include "mlx/fence.h"
|
||||
#include <csignal>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void signal_handler(int signum);
|
||||
|
||||
MTL::Buffer* signal_buffer() {
|
||||
auto init = []() {
|
||||
signal(SIGTERM, signal_handler);
|
||||
auto dtor = [](void* buf) {
|
||||
allocator::free(static_cast<MTL::Buffer*>(buf));
|
||||
};
|
||||
auto buf = std::shared_ptr<void>(
|
||||
allocator::malloc_or_wait(sizeof(uint32_t)).ptr(), dtor);
|
||||
static_cast<uint32_t*>(
|
||||
static_cast<MTL::Buffer*>(buf.get())->contents())[0] = 0;
|
||||
return buf;
|
||||
};
|
||||
static std::shared_ptr<void> buf = init();
|
||||
return static_cast<MTL::Buffer*>(buf.get());
|
||||
}
|
||||
|
||||
void signal_handler(int signum) {
|
||||
auto buf = signal_buffer();
|
||||
static_cast<std::atomic_uint*>(buf->contents())[0] = 1;
|
||||
signal(signum, SIG_DFL);
|
||||
raise(signum);
|
||||
}
|
||||
|
||||
struct FenceImpl {
|
||||
FenceImpl() {
|
||||
auto d = metal::device(Device::gpu).mtl_device();
|
||||
@ -94,6 +121,7 @@ void Fence::wait(Stream stream, const array& x) {
|
||||
auto buf = static_cast<MTL::Buffer*>(f.fence);
|
||||
compute_encoder.set_buffer(buf, 0);
|
||||
compute_encoder.set_bytes(f.count, 1);
|
||||
compute_encoder.set_buffer(signal_buffer(), 2);
|
||||
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||
|
||||
d.get_command_buffer(idx)->addCompletedHandler(
|
||||
|
@ -39,13 +39,14 @@ constexpr constant metal::thread_scope thread_scope_system =
|
||||
// single thread kernel to spin wait for timestamp value
|
||||
[[kernel]] void fence_wait(
|
||||
volatile coherent(system) device uint* timestamp [[buffer(0)]],
|
||||
constant uint& value [[buffer(1)]]) {
|
||||
constant uint& value [[buffer(1)]],
|
||||
volatile coherent(system) device uint* sig_handler [[buffer(2)]]) {
|
||||
while (1) {
|
||||
metal::atomic_thread_fence(
|
||||
metal::mem_flags::mem_device,
|
||||
metal::memory_order_seq_cst,
|
||||
metal::thread_scope_system);
|
||||
if (timestamp[0] >= value) {
|
||||
if (timestamp[0] >= value || sig_handler[0] >= 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user