Faster synchronization Fence primitive (#1773)

* try faster synchronization

move event

fixes

update bench

fix

fix

* non-functioning kernel

* try alternative fence

* cleanup barrier

* get rid of event_fence

* update benchmarks

* doc string in metal fence
This commit is contained in:
Awni Hannun
2025-01-17 18:42:19 -08:00
committed by GitHub
parent 0c259961ac
commit a4667da1eb
11 changed files with 362 additions and 31 deletions

View File

@@ -45,6 +45,9 @@ build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention sdpa_vector.h)
if(MLX_METAL_VERSION GREATER_EQUAL 320)
build_kernel(fence)
endif()
set(STEEL_HEADERS
steel/defines.h

View File

@@ -0,0 +1,52 @@
// Copyright © 2024 Apple Inc.
#pragma METAL internals : enable
#ifndef __METAL_MEMORY_SCOPE_SYSTEM__
#define __METAL_MEMORY_SCOPE_SYSTEM__ 3
namespace metal {
constexpr constant metal::thread_scope thread_scope_system =
static_cast<thread_scope>(__METAL_MEMORY_SCOPE_SYSTEM__);
}
#endif
#include <metal_atomic>
[[kernel]] void input_coherent(
volatile coherent(system) device uint* input [[buffer(0)]],
const constant uint& size [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
if (index < size) {
input[index] = input[index];
}
metal::atomic_thread_fence(
metal::mem_flags::mem_device,
metal::memory_order_seq_cst,
metal::thread_scope_system);
}
// single thread kernel to update timestamp
[[kernel]] void fence_update(
volatile coherent(system) device uint* timestamp [[buffer(0)]],
constant uint& value [[buffer(1)]]) {
timestamp[0] = value;
metal::atomic_thread_fence(
metal::mem_flags::mem_device,
metal::memory_order_seq_cst,
metal::thread_scope_system);
}
// single thread kernel to spin wait for timestamp value
[[kernel]] void fence_wait(
volatile coherent(system) device uint* timestamp [[buffer(0)]],
constant uint& value [[buffer(1)]]) {
while (1) {
metal::atomic_thread_fence(
metal::mem_flags::mem_device,
metal::memory_order_seq_cst,
metal::thread_scope_system);
if (timestamp[0] >= value) {
break;
}
}
}