mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fences must exit
This commit is contained in:
parent
c4230747a1
commit
3ad9031a7f
@ -1,12 +1,39 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
// Copyright © 2024 Apple Inc.
|
||||||
#include "mlx/fence.h"
|
#include <csignal>
|
||||||
|
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/metal_impl.h"
|
#include "mlx/backend/metal/metal_impl.h"
|
||||||
|
#include "mlx/fence.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void signal_handler(int signum);
|
||||||
|
|
||||||
|
MTL::Buffer* signal_buffer() {
|
||||||
|
auto init = []() {
|
||||||
|
signal(SIGTERM, signal_handler);
|
||||||
|
auto dtor = [](void* buf) {
|
||||||
|
allocator::free(static_cast<MTL::Buffer*>(buf));
|
||||||
|
};
|
||||||
|
auto buf = std::shared_ptr<void>(
|
||||||
|
allocator::malloc_or_wait(sizeof(uint32_t)).ptr(), dtor);
|
||||||
|
static_cast<uint32_t*>(
|
||||||
|
static_cast<MTL::Buffer*>(buf.get())->contents())[0] = 0;
|
||||||
|
return buf;
|
||||||
|
};
|
||||||
|
static std::shared_ptr<void> buf = init();
|
||||||
|
return static_cast<MTL::Buffer*>(buf.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void signal_handler(int signum) {
|
||||||
|
auto buf = signal_buffer();
|
||||||
|
static_cast<std::atomic_uint*>(buf->contents())[0] = 1;
|
||||||
|
signal(signum, SIG_DFL);
|
||||||
|
raise(signum);
|
||||||
|
}
|
||||||
|
|
||||||
struct FenceImpl {
|
struct FenceImpl {
|
||||||
FenceImpl() {
|
FenceImpl() {
|
||||||
auto d = metal::device(Device::gpu).mtl_device();
|
auto d = metal::device(Device::gpu).mtl_device();
|
||||||
@ -94,6 +121,7 @@ void Fence::wait(Stream stream, const array& x) {
|
|||||||
auto buf = static_cast<MTL::Buffer*>(f.fence);
|
auto buf = static_cast<MTL::Buffer*>(f.fence);
|
||||||
compute_encoder.set_buffer(buf, 0);
|
compute_encoder.set_buffer(buf, 0);
|
||||||
compute_encoder.set_bytes(f.count, 1);
|
compute_encoder.set_bytes(f.count, 1);
|
||||||
|
compute_encoder.set_buffer(signal_buffer(), 2);
|
||||||
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||||
|
|
||||||
d.get_command_buffer(idx)->addCompletedHandler(
|
d.get_command_buffer(idx)->addCompletedHandler(
|
||||||
|
@ -39,13 +39,14 @@ constexpr constant metal::thread_scope thread_scope_system =
|
|||||||
// single thread kernel to spin wait for timestamp value
|
// single thread kernel to spin wait for timestamp value
|
||||||
[[kernel]] void fence_wait(
|
[[kernel]] void fence_wait(
|
||||||
volatile coherent(system) device uint* timestamp [[buffer(0)]],
|
volatile coherent(system) device uint* timestamp [[buffer(0)]],
|
||||||
constant uint& value [[buffer(1)]]) {
|
constant uint& value [[buffer(1)]],
|
||||||
|
volatile coherent(system) device uint* sig_handler [[buffer(2)]]) {
|
||||||
while (1) {
|
while (1) {
|
||||||
metal::atomic_thread_fence(
|
metal::atomic_thread_fence(
|
||||||
metal::mem_flags::mem_device,
|
metal::mem_flags::mem_device,
|
||||||
metal::memory_order_seq_cst,
|
metal::memory_order_seq_cst,
|
||||||
metal::thread_scope_system);
|
metal::thread_scope_system);
|
||||||
if (timestamp[0] >= value) {
|
if (timestamp[0] >= value || sig_handler[0] >= 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user