Faster synchronization Fence primitive (#1773)

* try faster synchronization

move event

fixes

update bench

fix

fix

* non-functioning kernel

* try alternative fence

* cleanup barrier

* get rid of event_fence

* update benchmarks

* doc string in metal fence
Awni Hannun 2025-01-17 18:42:19 -08:00, committed by GitHub
parent 0c259961ac
commit a4667da1eb
11 changed files with 362 additions and 31 deletions


@@ -0,0 +1,55 @@
import time

import mlx.core as mx

rank = mx.distributed.init().rank()


def timeit(fn, a):
    # warmup
    for _ in range(5):
        mx.eval(fn(a))

    its = 10
    tic = time.perf_counter()
    for _ in range(its):
        mx.eval(fn(a))
    toc = time.perf_counter()
    ms = 1000 * (toc - tic) / its
    return ms


def all_reduce_benchmark():
    a = mx.ones((5, 5), mx.int32)
    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            x = mx.distributed.all_sum(x)
            x = x - 1
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All Reduce: time per iteration {ms:.6f} (ms)")


def all_gather_benchmark():
    a = mx.ones((5, 5), mx.int32)
    its_per_eval = 100

    def fn(x):
        for _ in range(its_per_eval):
            x = mx.distributed.all_gather(x)[0]
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank == 0:
        print(f"All gather: time per iteration {ms:.6f} (ms)")


if __name__ == "__main__":
    all_reduce_benchmark()
    all_gather_benchmark()


@@ -89,6 +89,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp


@@ -134,6 +134,13 @@ CommandEncoder::~CommandEncoder() {
   enc_->release();
 }
 
+void CommandEncoder::set_buffer(
+    const MTL::Buffer* buf,
+    int idx,
+    int64_t offset /* = 0 */) {
+  enc_->setBuffer(buf, offset, idx);
+}
+
 void CommandEncoder::set_input_array(
     const array& a,
     int idx,
@@ -155,6 +162,10 @@ void CommandEncoder::set_output_array(
     int64_t offset /* = 0 */) {
   // Add barriers before adding the output to the output set
   set_input_array(a, idx, offset);
+  register_output_array(a);
+}
+
+void CommandEncoder::register_output_array(array& a) {
   all_outputs_.insert(a.buffer().ptr());
   auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
   if (concurrent_) {
@@ -189,6 +200,10 @@ void CommandEncoder::dispatch_threads(
   enc_->dispatchThreads(grid_dims, group_dims);
 }
 
+void CommandEncoder::barrier() {
+  enc_->memoryBarrier(MTL::BarrierScopeBuffers);
+}
+
 Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
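
The new `barrier()` exposes a buffer-scope memory barrier on the open encoder, and `register_output_array` adds a buffer to the output set without encoding anything; `Fence::update_gpu` below uses both to order its two dispatches. A minimal sketch of the pattern, assuming hypothetical kernel names:

#include "mlx/backend/metal/device.h"

// Sketch only: order two dispatches on one encoder so the second observes
// the first's buffer writes. "producer" and "consumer" are hypothetical
// kernel names, not kernels that exist in MLX.
void encode_dependent_pair(mlx::core::metal::Device& d, int stream_idx) {
  auto& enc = d.get_command_encoder(stream_idx);

  enc.set_compute_pipeline_state(d.get_kernel("producer"));
  enc.dispatch_threads(MTL::Size(1024, 1, 1), MTL::Size(256, 1, 1));

  // Buffer-scope barrier: later dispatches wait for earlier buffer writes.
  enc.barrier();

  enc.set_compute_pipeline_state(d.get_kernel("consumer"));
  enc.dispatch_threads(MTL::Size(1024, 1, 1), MTL::Size(256, 1, 1));
}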


@@ -60,9 +60,11 @@ struct CommandEncoder {
   void set_input_array(const array& a, int idx, int64_t offset = 0);
   void set_output_array(array& a, int idx, int64_t offset = 0);
+  void register_output_array(array& a);
   void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
   void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
   void maybeInsertBarrier();
+  void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0);
 
   void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
     enc_->setComputePipelineState(kernel);
@@ -110,6 +112,8 @@ struct CommandEncoder {
     return all_outputs_;
   };
 
+  void barrier();
+
  private:
   MTL::ComputeCommandEncoder* enc_;
   bool needs_barrier_{false};


@@ -6,6 +6,7 @@
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/event.h"
+#include "mlx/backend/metal/fence.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/scheduler.h"
@@ -26,23 +27,28 @@ void AllReduce::eval_gpu(
   assert(outputs.size() == 1);
 
   auto& in = inputs[0];
 
+  Fence f{stream()};
+
+  if (in.event().valid()) {
+    f.update_gpu(in);
+  }
+
   auto& out = outputs[0];
   if (in.is_donatable()) {
     out.move_shared_buffer(in);
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }
+  f.wait_gpu(out);
 
-  auto e = Event(stream());
-  e.set_value(1);
-  signal_and_wait(in.event(), e);
   auto task = [in = in,
                out = out,
-               e = std::move(e),
+               f = std::move(f),
                reduce_type = reduce_type_,
                group = group()]() mutable {
     if (in.event().valid()) {
-      in.event().wait();
+      f.wait();
     }
     switch (reduce_type) {
       case Sum:
@@ -52,7 +58,7 @@ void AllReduce::eval_gpu(
       default:
         throw std::runtime_error("Only all reduce sum is supported for now");
     }
-    e.signal();
+    f.update();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
@@ -67,17 +73,20 @@ void AllGather::eval_gpu(
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  auto e = Event(stream());
-  e.set_value(1);
-  signal_and_wait(in.event(), e);
+  Fence f{stream()};
+
+  if (in.event().valid()) {
+    f.update_gpu(in);
+  }
+  f.wait_gpu(out);
 
   auto task =
-      [in = in, out = out, e = std::move(e), group = group()]() mutable {
+      [in = in, out = out, f = std::move(f), group = group()]() mutable {
         if (in.event().valid()) {
-          in.event().wait();
+          f.wait();
         }
         distributed::detail::all_gather(group, in, out);
-        e.signal();
+        f.update();
       };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
@@ -89,22 +98,28 @@ void Send::eval_gpu(
   assert(outputs.size() == 1);
 
   auto& in = inputs[0];
 
+  // Encode a signal event for the input
+  Fence f{stream()};
+
+  if (in.event().valid()) {
+    f.update_gpu(in);
+  }
+
   auto& out = outputs[0];
   move_or_copy(in, out);
 
   // Schedule an async send on the comm stream
-  auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
+  auto task = [in = in,
+               out = out,
+               f = std::move(f),
+               group = group(),
+               dst = dst_]() mutable {
     if (in.event().valid()) {
-      in.event().wait();
+      f.wait();
     }
     distributed::detail::send(group, out, dst);
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
-
-  // Encode a signal event for the input
-  if (in.event().valid()) {
-    encode_signal(in.event());
-  }
 }
 
 void Recv::eval_gpu(
@@ -117,16 +132,14 @@ void Recv::eval_gpu(
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  auto e = Event(stream());
-  e.set_value(1);
-  encode_wait(e);
+  Fence f{stream()};
+  f.wait_gpu(out);
 
   // Schedule an async recv on the comm stream
   auto task =
-      [out = out, e = std::move(e), group = group(), src = src_]() mutable {
+      [out = out, f = std::move(f), group = group(), src = src_]() mutable {
         distributed::detail::recv(group, out, src);
-        e.signal();
+        f.update();
       };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
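
Each primitive above follows the same pairing discipline: `update_gpu(in)` on the GPU stream pairs with `wait()` on the communication thread, and `wait_gpu(out)` pairs with `update()` once the collective finishes. Condensed from `AllGather::eval_gpu` (a sketch, not a drop-in replacement):

// Two pairings order the work:
//   update_gpu(in)  <->  f.wait()    (GPU produces, CPU consumes)
//   wait_gpu(out)   <->  f.update()  (CPU produces, GPU consumes)
Fence f{stream()};
if (in.event().valid()) {
  f.update_gpu(in); // GPU marks `in` computed and visible to the CPU
}
f.wait_gpu(out); // GPU kernels reading `out` stall until f.update()

auto task = [in, out, f = std::move(f), group = group()]() mutable {
  if (in.event().valid()) {
    f.wait(); // CPU blocks until the GPU-side update_gpu lands
  }
  distributed::detail::all_gather(group, in, out);
  f.update(); // releases the GPU-side wait_gpu on `out`
};
scheduler::enqueue(detail::communication_stream(), std::move(task));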

mlx/backend/metal/fence.cpp (new file, +145)

@@ -0,0 +1,145 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/fence.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/utils.h"

namespace mlx::core {

Fence::Fence(const Stream& stream) : stream_(stream) {
  auto d = metal::device(stream.device).mtl_device();
  if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {
    use_fast_ = false;
  } else if (__builtin_available(macOS 15, iOS 18, *)) {
    use_fast_ = env::metal_fast_synch();
  }
  if (!use_fast_) {
    // Wraps a Metal SharedEvent
    auto dtor = [](void* ptr) {
      auto p = metal::new_scoped_memory_pool();
      static_cast<MTL::SharedEvent*>(ptr)->release();
    };
    auto p = metal::new_scoped_memory_pool();
    fence_ = std::shared_ptr<void>(
        metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
  } else {
    auto dtor = [](void* buf) {
      allocator::free(static_cast<MTL::Buffer*>(buf));
    };
    auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
    fence_ = std::shared_ptr<void>(buf, dtor);
    cpu_value()[0] = 0;
  }
}

void Fence::wait_gpu(array& x) {
  gpu_count_++;
  auto& d = metal::device(stream_.device);
  auto idx = stream_.index;

  if (!use_fast_) {
    d.end_encoding(idx);
    auto command_buffer = d.get_command_buffer(idx);
    command_buffer->encodeWait(
        static_cast<MTL::Event*>(fence_.get()), gpu_count_);
    command_buffer->addCompletedHandler(
        [fence = fence_](MTL::CommandBuffer* cbuf) {});
    return;
  }

  auto& compute_encoder = d.get_command_encoder(idx);

  // Register the output so that no kernel which depends on it can start
  // before this wait is done
  compute_encoder.register_output_array(x);

  auto kernel = d.get_kernel("fence_wait");
  MTL::Size kernel_dims = MTL::Size(1, 1, 1);
  compute_encoder.set_compute_pipeline_state(kernel);

  auto buf = static_cast<MTL::Buffer*>(fence_.get());
  compute_encoder.set_buffer(buf, 0);
  compute_encoder.set_bytes(gpu_count_, 1);
  compute_encoder.dispatch_threads(kernel_dims, kernel_dims);

  d.get_command_buffer(idx)->addCompletedHandler(
      [fence = fence_](MTL::CommandBuffer* cbuf) {});
}

void Fence::update_gpu(const array& x) {
  gpu_count_++;
  auto& d = metal::device(stream_.device);
  auto idx = stream_.index;

  if (!use_fast_) {
    d.end_encoding(idx);
    auto command_buffer = d.get_command_buffer(idx);
    command_buffer->encodeSignalEvent(
        static_cast<MTL::Event*>(fence_.get()), gpu_count_);
    command_buffer->addCompletedHandler(
        [fence = fence_](MTL::CommandBuffer* cbuf) {});
    return;
  }

  // Launch the input visibility kernel
  auto& compute_encoder = d.get_command_encoder(idx);
  auto kernel = d.get_kernel("input_coherent");
  uint32_t nthreads =
      (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
  MTL::Size group_dims = MTL::Size(1024, 1, 1);
  MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
  compute_encoder.set_compute_pipeline_state(kernel);
  compute_encoder.set_input_array(x, 0);
  compute_encoder.set_bytes(nthreads, 1);
  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);

  // Barrier on the previous kernel
  compute_encoder.barrier();

  // Launch the value update kernel
  kernel = d.get_kernel("fence_update");
  MTL::Size kernel_dims = MTL::Size(1, 1, 1);
  compute_encoder.set_compute_pipeline_state(kernel);
  auto buf = static_cast<MTL::Buffer*>(fence_.get());
  compute_encoder.set_buffer(buf, 0);
  compute_encoder.set_bytes(gpu_count_, 1);
  compute_encoder.dispatch_threads(kernel_dims, kernel_dims);

  d.get_command_buffer(idx)->addCompletedHandler(
      [fence = fence_](MTL::CommandBuffer* cbuf) {});
}

void Fence::wait() {
  cpu_count_++;
  if (!use_fast_) {
    if (!static_cast<MTL::SharedEvent*>(fence_.get())
             ->waitUntilSignaledValue(cpu_count_, -1)) {
      throw std::runtime_error("[Fence::wait] Timed out");
    }
    return;
  }
  while (cpu_value()[0] < cpu_count_) {
  }
}

void Fence::update() {
  cpu_count_++;
  if (!use_fast_) {
    static_cast<MTL::SharedEvent*>(fence_.get())->setSignaledValue(cpu_count_);
    return;
  }
  cpu_value()[0] = cpu_count_;
}

std::atomic_uint* Fence::cpu_value() {
  return static_cast<std::atomic_uint*>(
      static_cast<MTL::Buffer*>(fence_.get())->contents());
}

} // namespace mlx::core
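
In fast mode the fence is a monotonic-counter handshake over one shared `uint32_t`: `update()` publishes count N and `wait()` spins until the value reaches N, with `fence_update`/`fence_wait` playing the same roles on the GPU. A self-contained CPU-only analogue of that handshake (two threads standing in for the GPU and CPU sides; illustrative only):

#include <atomic>
#include <thread>

int main() {
  std::atomic<unsigned> value{0}; // plays the role of the fence's shared buffer

  // Producer side, analogous to Fence::update(): bump a count, publish it.
  std::thread producer([&] {
    unsigned update_count = 1;
    value.store(update_count, std::memory_order_seq_cst);
  });

  // Consumer side, analogous to Fence::wait(): spin until the value catches up.
  unsigned wait_count = 1;
  while (value.load(std::memory_order_seq_cst) < wait_count) {
  }

  producer.join();
  return 0;
}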

mlx/backend/metal/fence.h (new file, +40)

@@ -0,0 +1,40 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/device.h"

namespace mlx::core {

/* A fence to be used for synchronizing work between the CPU and GPU
 *
 * Calls to `update_gpu` should be paired with calls to `wait`. This ensures
 * that the array passed to `update_gpu` is computed and visible to the CPU
 * after the call to `wait` returns.
 *
 * Calls to `update` should be paired with calls to `wait_gpu`. This ensures
 * that the array passed to `wait_gpu` will not be read by the GPU until the
 * CPU has called `update`.
 *
 * The fence supports slow (default) and fast modes. Fast mode requires
 * setting the environment variable `MLX_METAL_FAST_SYNCH=1`. Fast mode also
 * requires Metal 3.2+ (macOS 15+, iOS 18+).
 */
class Fence {
 public:
  Fence(const Stream& stream);

  void update_gpu(const array& x);
  void wait_gpu(array& x);

  void wait();
  void update();

 private:
  Stream stream_;
  std::shared_ptr<void> fence_;
  uint32_t cpu_count_{0};
  uint32_t gpu_count_{0};
  bool use_fast_{false};
  std::atomic_uint* cpu_value();
};

} // namespace mlx::core
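
A minimal usage sketch of the two pairings described in the comment (`in` and `out` are placeholder arrays; the distributed primitives above show the real call sites):

// GPU produces `in` -> CPU consumes:
Fence f{stream};
f.update_gpu(in); // encoded on the GPU stream after `in`'s kernels
// ... later, on a CPU worker thread:
f.wait(); // returns once `in` is computed and visible to the CPU

// CPU produces `out` -> GPU consumes:
Fence g{stream};
g.wait_gpu(out); // GPU kernels reading `out` stall here
// ... CPU fills `out`, then:
g.update(); // releases the GPU-side wait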


@@ -45,6 +45,9 @@ build_kernel(random)
 build_kernel(rms_norm)
 build_kernel(rope)
 build_kernel(scaled_dot_product_attention sdpa_vector.h)
+if(MLX_METAL_VERSION GREATER_EQUAL 320)
+  build_kernel(fence)
+endif()
 
 set(STEEL_HEADERS
     steel/defines.h


@@ -0,0 +1,52 @@
// Copyright © 2024 Apple Inc.

#pragma METAL internals : enable

#ifndef __METAL_MEMORY_SCOPE_SYSTEM__
#define __METAL_MEMORY_SCOPE_SYSTEM__ 3
namespace metal {
constexpr constant metal::thread_scope thread_scope_system =
    static_cast<thread_scope>(__METAL_MEMORY_SCOPE_SYSTEM__);
}
#endif

#include <metal_atomic>

[[kernel]] void input_coherent(
    volatile coherent(system) device uint* input [[buffer(0)]],
    const constant uint& size [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  if (index < size) {
    input[index] = input[index];
  }
  metal::atomic_thread_fence(
      metal::mem_flags::mem_device,
      metal::memory_order_seq_cst,
      metal::thread_scope_system);
}

// Single-thread kernel to update the timestamp
[[kernel]] void fence_update(
    volatile coherent(system) device uint* timestamp [[buffer(0)]],
    constant uint& value [[buffer(1)]]) {
  timestamp[0] = value;
  metal::atomic_thread_fence(
      metal::mem_flags::mem_device,
      metal::memory_order_seq_cst,
      metal::thread_scope_system);
}

// Single-thread kernel to spin-wait for the timestamp value
[[kernel]] void fence_wait(
    volatile coherent(system) device uint* timestamp [[buffer(0)]],
    constant uint& value [[buffer(1)]]) {
  while (1) {
    metal::atomic_thread_fence(
        metal::mem_flags::mem_device,
        metal::memory_order_seq_cst,
        metal::thread_scope_system);
    if (timestamp[0] >= value) {
      break;
    }
  }
}


@@ -2,6 +2,7 @@
 #include <memory>
 
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/event.h"
+#include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 #include "mlx/scheduler.h"
@@ -69,16 +70,13 @@ std::function<void()> make_task(array arr, bool signal) {
     if (signal ||
         d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
+      d.end_encoding(s.index);
       if (signal) {
-        command_buffer->encodeSignalEvent(
-            static_cast<MTL::Event*>(arr.event().raw_event().get()),
-            arr.event().value());
+        encode_signal(arr.event());
       }
-      d.end_encoding(s.index);
       scheduler::notify_new_task(s);
       command_buffer->addCompletedHandler(
-          [s, buffers = std::move(buffers), event = arr.event()](
-              MTL::CommandBuffer* cbuf) {
+          [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
            scheduler::notify_task_completion(s);
            check_error(cbuf);
          });


@@ -126,6 +126,11 @@ inline int max_ops_per_buffer() {
   return max_ops_per_buffer_;
 }
 
+inline bool metal_fast_synch() {
+  static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0);
+  return metal_fast_synch;
+}
+
 } // namespace env
 
 } // namespace mlx::core