mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Faster synchronization Fence primitive (#1773)
* try faster synchronization move event fixes update bench fix fix * non-functioning kernel * try alternative fence * cleanup barrier * get rid of event_fence * update benchmarks * doc string in metal fence
This commit is contained in:
@@ -45,6 +45,9 @@ build_kernel(random)
|
||||
build_kernel(rms_norm)
|
||||
build_kernel(rope)
|
||||
build_kernel(scaled_dot_product_attention sdpa_vector.h)
|
||||
if(MLX_METAL_VERSION GREATER_EQUAL 320)
|
||||
build_kernel(fence)
|
||||
endif()
|
||||
|
||||
set(STEEL_HEADERS
|
||||
steel/defines.h
|
||||
|
||||
52
mlx/backend/metal/kernels/fence.metal
Normal file
52
mlx/backend/metal/kernels/fence.metal
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma METAL internals : enable
|
||||
|
||||
#ifndef __METAL_MEMORY_SCOPE_SYSTEM__
|
||||
#define __METAL_MEMORY_SCOPE_SYSTEM__ 3
|
||||
namespace metal {
|
||||
constexpr constant metal::thread_scope thread_scope_system =
|
||||
static_cast<thread_scope>(__METAL_MEMORY_SCOPE_SYSTEM__);
|
||||
}
|
||||
#endif
|
||||
|
||||
#include <metal_atomic>
|
||||
|
||||
[[kernel]] void input_coherent(
|
||||
volatile coherent(system) device uint* input [[buffer(0)]],
|
||||
const constant uint& size [[buffer(1)]],
|
||||
uint index [[thread_position_in_grid]]) {
|
||||
if (index < size) {
|
||||
input[index] = input[index];
|
||||
}
|
||||
metal::atomic_thread_fence(
|
||||
metal::mem_flags::mem_device,
|
||||
metal::memory_order_seq_cst,
|
||||
metal::thread_scope_system);
|
||||
}
|
||||
|
||||
// single thread kernel to update timestamp
|
||||
[[kernel]] void fence_update(
|
||||
volatile coherent(system) device uint* timestamp [[buffer(0)]],
|
||||
constant uint& value [[buffer(1)]]) {
|
||||
timestamp[0] = value;
|
||||
metal::atomic_thread_fence(
|
||||
metal::mem_flags::mem_device,
|
||||
metal::memory_order_seq_cst,
|
||||
metal::thread_scope_system);
|
||||
}
|
||||
|
||||
// single thread kernel to spin wait for timestamp value
|
||||
[[kernel]] void fence_wait(
|
||||
volatile coherent(system) device uint* timestamp [[buffer(0)]],
|
||||
constant uint& value [[buffer(1)]]) {
|
||||
while (1) {
|
||||
metal::atomic_thread_fence(
|
||||
metal::mem_flags::mem_device,
|
||||
metal::memory_order_seq_cst,
|
||||
metal::thread_scope_system);
|
||||
if (timestamp[0] >= value) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user