diff --git a/benchmarks/python/synchronize_bench.py b/benchmarks/python/synchronize_bench.py
new file mode 100644
index 0000000000..db006d70bf
--- /dev/null
+++ b/benchmarks/python/synchronize_bench.py
@@ -0,0 +1,55 @@
+import time
+
+import mlx.core as mx
+
+rank = mx.distributed.init().rank()
+
+
+def timeit(fn, a):
+
+    # warmup
+    for _ in range(5):
+        mx.eval(fn(a))
+
+    its = 10
+    tic = time.perf_counter()
+    for _ in range(its):
+        mx.eval(fn(a))
+    toc = time.perf_counter()
+    ms = 1000 * (toc - tic) / its
+    return ms
+
+
+def all_reduce_benchmark():
+    a = mx.ones((5, 5), mx.int32)
+
+    its_per_eval = 100
+
+    def fn(x):
+        for _ in range(its_per_eval):
+            x = mx.distributed.all_sum(x)
+            x = x - 1
+        return x
+
+    ms = timeit(fn, a) / its_per_eval
+    if rank == 0:
+        print(f"All Reduce: time per iteration {ms:.6f} (ms)")
+
+
+def all_gather_benchmark():
+    a = mx.ones((5, 5), mx.int32)
+    its_per_eval = 100
+
+    def fn(x):
+        for _ in range(its_per_eval):
+            x = mx.distributed.all_gather(x)[0]
+        return x
+
+    ms = timeit(fn, a) / its_per_eval
+    if rank == 0:
+        print(f"All gather: time per iteration {ms:.6f} (ms)")
+
+
+if __name__ == "__main__":
+    all_reduce_benchmark()
+    all_gather_benchmark()
diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt
index 26f6f11381..b932a5654e 100644
--- a/mlx/backend/metal/CMakeLists.txt
+++ b/mlx/backend/metal/CMakeLists.txt
@@ -89,6 +89,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index b06c35c61a..9379733b68 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -134,6 +134,13 @@ CommandEncoder::~CommandEncoder() {
   enc_->release();
 }
 
+void CommandEncoder::set_buffer(
+    const MTL::Buffer* buf,
+    int idx,
+    int64_t offset /* = 0 */) {
+  enc_->setBuffer(buf, offset, idx);
+}
+
 void CommandEncoder::set_input_array(
     const array& a,
     int idx,
@@ -155,6 +162,10 @@ void CommandEncoder::set_output_array(
     int64_t offset /* = 0 */) {
   // Add barriers before adding the output to the output set
   set_input_array(a, idx, offset);
+  register_output_array(a);
+}
+
+void CommandEncoder::register_output_array(array& a) {
   all_outputs_.insert(a.buffer().ptr());
   auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
   if (concurrent_) {
@@ -189,6 +200,10 @@ void CommandEncoder::dispatch_threads(
   enc_->dispatchThreads(grid_dims, group_dims);
 }
 
+void CommandEncoder::barrier() {
+  enc_->memoryBarrier(MTL::BarrierScopeBuffers);
+}
+
 Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
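The `device.cpp` changes above add three small hooks the fence machinery relies on: `set_buffer` binds a raw `MTL::Buffer` that is not wrapped in an `array`, `register_output_array` records an output for hazard tracking without binding it to an argument slot, and `barrier()` orders buffer writes between dispatches on the same encoder. A minimal sketch of how a dispatch might combine them; the surrounding setup (`d`, `kernel`, `buf`, `out`) is assumed and not part of this diff:

```cpp
// Illustrative only: combining the new CommandEncoder hooks from this diff.
#include "mlx/backend/metal/device.h"

namespace mlx::core {

void example_dispatch(
    metal::Device& d,
    int stream_idx,
    MTL::ComputePipelineState* kernel,
    const MTL::Buffer* buf,
    array& out) {
  auto& enc = d.get_command_encoder(stream_idx);
  enc.set_compute_pipeline_state(kernel);
  // Bind a raw MTL::Buffer that has no mlx::core::array wrapper.
  enc.set_buffer(buf, 0);
  // Record `out` for hazard tracking without binding it to a slot.
  enc.register_output_array(out);
  enc.dispatch_threads(MTL::Size(1, 1, 1), MTL::Size(1, 1, 1));
  // Make the writes above visible to subsequent dispatches on this encoder.
  enc.barrier();
}

} // namespace mlx::core
```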
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index 09397dc362..1c58cdd3d0 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -60,9 +60,11 @@ struct CommandEncoder {
 
   void set_input_array(const array& a, int idx, int64_t offset = 0);
   void set_output_array(array& a, int idx, int64_t offset = 0);
+  void register_output_array(array& a);
   void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
   void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
   void maybeInsertBarrier();
+  void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0);
 
   void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
     enc_->setComputePipelineState(kernel);
@@ -110,6 +112,8 @@ struct CommandEncoder {
     return all_outputs_;
   };
 
+  void barrier();
+
  private:
   MTL::ComputeCommandEncoder* enc_;
   bool needs_barrier_{false};
diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp
index be4cc18480..6f36e94847 100644
--- a/mlx/backend/metal/distributed.cpp
+++ b/mlx/backend/metal/distributed.cpp
@@ -6,6 +6,7 @@
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/event.h"
+#include "mlx/backend/metal/fence.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/scheduler.h"
@@ -26,23 +27,28 @@ void AllReduce::eval_gpu(
   assert(outputs.size() == 1);
 
   auto& in = inputs[0];
+
+  Fence f{stream()};
+
+  if (in.event().valid()) {
+    f.update_gpu(in);
+  }
+
   auto& out = outputs[0];
   if (in.is_donatable()) {
     out.move_shared_buffer(in);
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }
+  f.wait_gpu(out);
 
-  auto e = Event(stream());
-  e.set_value(1);
-  signal_and_wait(in.event(), e);
   auto task = [in = in,
                out = out,
-               e = std::move(e),
+               f = std::move(f),
                reduce_type = reduce_type_,
                group = group()]() mutable {
     if (in.event().valid()) {
-      in.event().wait();
+      f.wait();
     }
     switch (reduce_type) {
       case Sum:
@@ -52,7 +58,7 @@ void AllReduce::eval_gpu(
       default:
         throw std::runtime_error("Only all reduce sum is supported for now");
     }
-    e.signal();
+    f.update();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
@@ -67,17 +73,20 @@ void AllGather::eval_gpu(
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  auto e = Event(stream());
-  e.set_value(1);
-  signal_and_wait(in.event(), e);
+  Fence f{stream()};
+
+  if (in.event().valid()) {
+    f.update_gpu(in);
+  }
+  f.wait_gpu(out);
 
   auto task =
-      [in = in, out = out, e = std::move(e), group = group()]() mutable {
+      [in = in, out = out, f = std::move(f), group = group()]() mutable {
         if (in.event().valid()) {
-          in.event().wait();
+          f.wait();
        }
         distributed::detail::all_gather(group, in, out);
-        e.signal();
+        f.update();
       };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
@@ -89,22 +98,28 @@ void Send::eval_gpu(
   assert(outputs.size() == 1);
 
   auto& in = inputs[0];
+
+  // Signal the fence once the input is available
+  Fence f{stream()};
+  if (in.event().valid()) {
+    f.update_gpu(in);
+  }
+
   auto& out = outputs[0];
   move_or_copy(in, out);
 
   // Schedule an async send on the comm stream
-  auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
+  auto task = [in = in,
+               out = out,
+               f = std::move(f),
+               group = group(),
+               dst = dst_]() mutable {
     if (in.event().valid()) {
-      in.event().wait();
+      f.wait();
     }
     distributed::detail::send(group, out, dst);
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
-
-  // Encode a signal event for the input
-  if (in.event().valid()) {
-    encode_signal(in.event());
-  }
 }
@@ -117,16 +132,14 @@ void Recv::eval_gpu(
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  auto e = Event(stream());
-  e.set_value(1);
-
-  encode_wait(e);
+  Fence f{stream()};
+  f.wait_gpu(out);
 
   // Schedule an async recv on the comm stream
   auto task =
-      [out = out, e = std::move(e), group = group(), src = src_]() mutable {
+      [out = out, f = std::move(f), group = group(), src = src_]() mutable {
         distributed::detail::recv(group, out, src);
-        e.signal();
+        f.update();
       };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
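The pattern in `distributed.cpp` is the same across all four primitives: the fence replaces the previous `Event`-based `signal_and_wait`, with the GPU side encoded at graph-build time and the CPU side run on the communication thread. A schematic of the pairing; donation and validity checks are omitted, and `do_collective` is a hypothetical stand-in for `all_sum` / `all_gather` / `send` / `recv`:

```cpp
// Schematic sketch, not the literal primitive code.
Fence f{stream()};
f.update_gpu(in); // GPU signals once `in` is computed and visible
f.wait_gpu(out);  // GPU kernels reading `out` wait for the CPU side

auto task = [in, out, f = std::move(f)]() mutable {
  f.wait();               // CPU blocks until update_gpu's signal lands
  do_collective(in, out); // runs on the communication stream
  f.update();             // releases the GPU-side wait_gpu(out)
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
```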
diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp
new file mode 100644
index 0000000000..54ce8ea8d3
--- /dev/null
+++ b/mlx/backend/metal/fence.cpp
@@ -0,0 +1,145 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/metal/fence.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/event.h"
+#include "mlx/backend/metal/metal_impl.h"
+#include "mlx/utils.h"
+
+namespace mlx::core {
+
+Fence::Fence(const Stream& stream) : stream_(stream) {
+  auto d = metal::device(stream.device).mtl_device();
+  if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {
+    use_fast_ = false;
+  } else if (__builtin_available(macOS 15, iOS 18, *)) {
+    use_fast_ = env::metal_fast_synch();
+  }
+  if (!use_fast_) {
+    // Wraps Metal SharedEvent
+    auto dtor = [](void* ptr) {
+      auto p = metal::new_scoped_memory_pool();
+      static_cast<MTL::SharedEvent*>(ptr)->release();
+    };
+    auto p = metal::new_scoped_memory_pool();
+    fence_ = std::shared_ptr<void>(
+        metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
+  } else {
+    auto dtor = [](void* buf) {
+      allocator::free(static_cast<MTL::Buffer*>(buf));
+    };
+    auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
+    fence_ = std::shared_ptr<void>(buf, dtor);
+    cpu_value()[0] = 0;
+  }
+}
+
+void Fence::wait_gpu(array& x) {
+  gpu_count_++;
+  auto& d = metal::device(stream_.device);
+  auto idx = stream_.index;
+
+  if (!use_fast_) {
+    d.end_encoding(idx);
+    auto command_buffer = d.get_command_buffer(idx);
+    command_buffer->encodeWait(
+        static_cast<MTL::SharedEvent*>(fence_.get()), gpu_count_);
+    command_buffer->addCompletedHandler(
+        [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+    return;
+  }
+
+  auto& compute_encoder = d.get_command_encoder(idx);
+
+  // Register the output to ensure that no kernel which depends on the
+  // output starts before this one is done
+  compute_encoder.register_output_array(x);
+
+  auto kernel = d.get_kernel("fence_wait");
+  MTL::Size kernel_dims = MTL::Size(1, 1, 1);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  auto buf = static_cast<MTL::Buffer*>(fence_.get());
+  compute_encoder.set_buffer(buf, 0);
+  compute_encoder.set_bytes(gpu_count_, 1);
+  compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
+
+  d.get_command_buffer(idx)->addCompletedHandler(
+      [fence = fence_](MTL::CommandBuffer* cbuf) {});
+}
+
+void Fence::update_gpu(const array& x) {
+  gpu_count_++;
+  auto& d = metal::device(stream_.device);
+  auto idx = stream_.index;
+
+  if (!use_fast_) {
+    d.end_encoding(idx);
+    auto command_buffer = d.get_command_buffer(idx);
+    command_buffer->encodeSignalEvent(
+        static_cast<MTL::SharedEvent*>(fence_.get()), gpu_count_);
+    command_buffer->addCompletedHandler(
+        [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
+    return;
+  }
+
+  // Launch input visibility kernel
+  auto& compute_encoder = d.get_command_encoder(idx);
+  auto kernel = d.get_kernel("input_coherent");
+  uint32_t nthreads =
+      (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
+
+  MTL::Size group_dims = MTL::Size(1024, 1, 1);
+  MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
+  compute_encoder.set_compute_pipeline_state(kernel);
+  compute_encoder.set_input_array(x, 0);
+  compute_encoder.set_bytes(nthreads, 1);
+  compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
+
+  // Barrier on previous kernel
+  compute_encoder.barrier();
+
+  // Launch value update kernel
+  kernel = d.get_kernel("fence_update");
+  MTL::Size kernel_dims = MTL::Size(1, 1, 1);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  auto buf = static_cast<MTL::Buffer*>(fence_.get());
+  compute_encoder.set_buffer(buf, 0);
+  compute_encoder.set_bytes(gpu_count_, 1);
+  compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
+
+  d.get_command_buffer(idx)->addCompletedHandler(
+      [fence = fence_](MTL::CommandBuffer* cbuf) {});
+}
+
+void Fence::wait() {
+  cpu_count_++;
+  if (!use_fast_) {
+    if (!static_cast<MTL::SharedEvent*>(fence_.get())
+             ->waitUntilSignaledValue(cpu_count_, -1)) {
+      throw std::runtime_error("[Fence::wait] Timed out");
+    }
+    return;
+  }
+  while (cpu_value()[0] < cpu_count_) {
+  }
+}
+
+void Fence::update() {
+  cpu_count_++;
+
+  if (!use_fast_) {
+    static_cast<MTL::SharedEvent*>(fence_.get())->setSignaledValue(cpu_count_);
+    return;
+  }
+
+  cpu_value()[0] = cpu_count_;
+}
+
+std::atomic_uint* Fence::cpu_value() {
+  return static_cast<std::atomic_uint*>(
+      static_cast<MTL::Buffer*>(fence_.get())->contents());
+}
+
+} // namespace mlx::core
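In fast mode the fence is just a 4-byte timestamp in memory shared between CPU and GPU: `update` / `fence_update` bump it, `wait` / `fence_wait` spin until it reaches the expected count, and `input_coherent` plus the encoder `barrier()` make the payload visible first. A runnable CPU-only model of that timestamp protocol (illustrative, not the MLX API):

```cpp
// CPU-only model of the fast-path protocol: one side bumps a shared counter,
// the other spins until the counter catches up. In the real code the shared
// word lives in a Metal buffer visible to both CPU and GPU.
#include <atomic>
#include <cstdint>
#include <thread>

int main() {
  std::atomic<uint32_t> timestamp{0};
  uint32_t producer_count = 0;
  uint32_t consumer_count = 0;

  std::thread producer([&] {
    // Equivalent of Fence::update() (or the fence_update kernel).
    timestamp.store(++producer_count, std::memory_order_seq_cst);
  });

  // Equivalent of Fence::wait() (or the fence_wait kernel): spin until the
  // timestamp reaches the expected value.
  ++consumer_count;
  while (timestamp.load(std::memory_order_seq_cst) < consumer_count) {
  }

  producer.join();
  return 0;
}
```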
diff --git a/mlx/backend/metal/fence.h b/mlx/backend/metal/fence.h
new file mode 100644
index 0000000000..28e04e6430
--- /dev/null
+++ b/mlx/backend/metal/fence.h
@@ -0,0 +1,40 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/metal/device.h"
+
+namespace mlx::core {
+
+/* A fence to be used for synchronizing work between the CPU and GPU
+ *
+ * Calls to `update_gpu` should be paired with calls to `wait`. This ensures
+ * that the array passed to `update_gpu` is computed and visible to the CPU
+ * after the call to `wait` returns.
+ *
+ * Calls to `update` should be paired with calls to `wait_gpu`. This ensures
+ * that the array passed to `wait_gpu` will not be read by the GPU until the
+ * CPU has called `update`.
+ *
+ * The fence supports slow (default) and fast mode. Fast mode requires setting
+ * the environment variable `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires
+ * Metal 3.2+ (macOS 15+, iOS 18+).
+ */
+class Fence {
+ public:
+  Fence(const Stream& stream);
+
+  void update_gpu(const array& x);
+  void wait_gpu(array& x);
+
+  void wait();
+  void update();
+
+ private:
+  Stream stream_;
+  std::shared_ptr<void> fence_;
+  uint32_t cpu_count_{0};
+  uint32_t gpu_count_{0};
+  bool use_fast_{false};
+  std::atomic_uint* cpu_value();
+};
+
+} // namespace mlx::core
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 9c0856425a..ff9c7fa30d 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -45,6 +45,9 @@ build_kernel(random)
 build_kernel(rms_norm)
 build_kernel(rope)
 build_kernel(scaled_dot_product_attention sdpa_vector.h)
+if(MLX_METAL_VERSION GREATER_EQUAL 320)
+  build_kernel(fence)
+endif()
 
 set(STEEL_HEADERS
     steel/defines.h
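The header comment above fixes the usage contract; read the two pairings directionally: `update_gpu`/`wait` moves data GPU→CPU, and `update`/`wait_gpu` moves it CPU→GPU. A sketch of both directions, where `produce_on_cpu` and `consume_on_cpu` are hypothetical helpers, not part of this diff:

```cpp
// Sketch of the two pairings from the header comment above.
void gpu_to_cpu(Fence& f, const array& x) {
  f.update_gpu(x); // encoded on the GPU stream after the kernels writing x
  f.wait();        // CPU: returns once x is computed and visible
  consume_on_cpu(x);
}

void cpu_to_gpu(Fence& f, array& y) {
  f.wait_gpu(y);     // encoded on the GPU stream before kernels reading y
  produce_on_cpu(y); // CPU fills y
  f.update();        // CPU: releases the GPU-side wait
}
```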
diff --git a/mlx/backend/metal/kernels/fence.metal b/mlx/backend/metal/kernels/fence.metal
new file mode 100644
index 0000000000..f736b33b07
--- /dev/null
+++ b/mlx/backend/metal/kernels/fence.metal
@@ -0,0 +1,52 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma METAL internals : enable
+
+#ifndef __METAL_MEMORY_SCOPE_SYSTEM__
+#define __METAL_MEMORY_SCOPE_SYSTEM__ 3
+namespace metal {
+constexpr constant metal::thread_scope thread_scope_system =
+    static_cast<metal::thread_scope>(__METAL_MEMORY_SCOPE_SYSTEM__);
+}
+#endif
+
+#include <metal_atomic>
+
+[[kernel]] void input_coherent(
+    volatile coherent(system) device uint* input [[buffer(0)]],
+    const constant uint& size [[buffer(1)]],
+    uint index [[thread_position_in_grid]]) {
+  if (index < size) {
+    input[index] = input[index];
+  }
+  metal::atomic_thread_fence(
+      metal::mem_flags::mem_device,
+      metal::memory_order_seq_cst,
+      metal::thread_scope_system);
+}
+
+// single thread kernel to update timestamp
+[[kernel]] void fence_update(
+    volatile coherent(system) device uint* timestamp [[buffer(0)]],
+    constant uint& value [[buffer(1)]]) {
+  timestamp[0] = value;
+  metal::atomic_thread_fence(
+      metal::mem_flags::mem_device,
+      metal::memory_order_seq_cst,
+      metal::thread_scope_system);
+}
+
+// single thread kernel to spin wait for timestamp value
+[[kernel]] void fence_wait(
+    volatile coherent(system) device uint* timestamp [[buffer(0)]],
+    constant uint& value [[buffer(1)]]) {
+  while (1) {
+    metal::atomic_thread_fence(
+        metal::mem_flags::mem_device,
+        metal::memory_order_seq_cst,
+        metal::thread_scope_system);
+    if (timestamp[0] >= value) {
+      break;
+    }
+  }
+}
diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp
index cd4d67a17e..a8bc716c13 100644
--- a/mlx/backend/metal/metal.cpp
+++ b/mlx/backend/metal/metal.cpp
@@ -2,6 +2,7 @@
 #include <memory>
 
 #include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/event.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 #include "mlx/scheduler.h"
@@ -69,16 +70,13 @@ std::function<void()> make_task(array arr, bool signal) {
     if (signal ||
         d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
-      d.end_encoding(s.index);
       if (signal) {
-        command_buffer->encodeSignalEvent(
-            static_cast<MTL::SharedEvent*>(arr.event().raw_event().get()),
-            arr.event().value());
+        encode_signal(arr.event());
       }
+      d.end_encoding(s.index);
       scheduler::notify_new_task(s);
       command_buffer->addCompletedHandler(
-          [s, buffers = std::move(buffers), event = arr.event()](
-              MTL::CommandBuffer* cbuf) {
+          [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
             scheduler::notify_task_completion(s);
             check_error(cbuf);
           });
diff --git a/mlx/utils.h b/mlx/utils.h
index aca6ccd6f7..94b0974d79 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -126,6 +126,11 @@ inline int max_ops_per_buffer() {
   return max_ops_per_buffer_;
 }
 
+inline bool metal_fast_synch() {
+  static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0);
+  return metal_fast_synch;
+}
+
 } // namespace env
 
 } // namespace mlx::core
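Note that `metal_fast_synch()` caches the variable in a function-local static, so `MLX_METAL_FAST_SYNCH=1` must be in the environment before the first `Fence` is constructed, e.g. when launching the benchmark at the top of this diff (`MLX_METAL_FAST_SYNCH=1 mpirun -n 2 python benchmarks/python/synchronize_bench.py`; the launcher shown is illustrative, use whatever your distributed setup requires). A self-contained sketch mirroring the helper's read-once behavior:

```cpp
// Mirrors the utils.h helper above, assuming a get_var that falls back to a
// default when the environment variable is unset.
#include <cstdlib>

inline int get_var(const char* name, int default_value) {
  if (const char* value = std::getenv(name)) {
    return std::atoi(value);
  }
  return default_value;
}

inline bool metal_fast_synch() {
  // Read once, then cached for the lifetime of the process.
  static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0);
  return metal_fast_synch;
}
```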