mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 19:56:40 +08:00
Faster synchronization Fence
primitive (#1773)
* try faster synchronization move event fixes update bench fix fix * non-functioning kernel * try alternative fence * cleanup barrier * get rid of event_fence * update benchmarks * doc string in metal fence
This commit is contained in:
parent
0c259961ac
commit
a4667da1eb
55
benchmarks/python/synchronize_bench.py
Normal file
55
benchmarks/python/synchronize_bench.py
Normal file
@ -0,0 +1,55 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
rank = mx.distributed.init().rank()
|
||||
|
||||
|
||||
def timeit(fn, a):
    """Return the mean wall-clock time, in milliseconds, of evaluating ``fn(a)``.

    ``fn(a)`` is evaluated with ``mx.eval`` five times untimed (warm-up) and
    then ten more times under a ``time.perf_counter`` stopwatch; the elapsed
    time is averaged over the ten timed evaluations.
    """
    warmup_runs = 5
    its = 10

    # Warm-up evaluations are excluded from the measurement.
    for _ in range(warmup_runs):
        mx.eval(fn(a))

    start = time.perf_counter()
    for _ in range(its):
        mx.eval(fn(a))
    stop = time.perf_counter()

    return 1000 * (stop - start) / its
|
||||
|
||||
|
||||
def all_reduce_benchmark():
    """Time a chain of ``mx.distributed.all_sum`` calls and report it on rank 0.

    Each evaluation performs ``its_per_eval`` all-sum operations back to back,
    so the figure printed is the evaluation time divided by that count.
    """
    its_per_eval = 100
    start_value = mx.ones((5, 5), mx.int32)

    def chained_all_sum(x):
        # Every step consumes the previous step's result, forming one long
        # dependency chain inside a single evaluation.
        for _ in range(its_per_eval):
            x = mx.distributed.all_sum(x)
            x = x - 1
        return x

    ms = timeit(chained_all_sum, start_value) / its_per_eval
    if rank == 0:
        print(f"All Reduce: time per iteration {ms:.6f} (ms)")
|
||||
|
||||
|
||||
def all_gather_benchmark():
    """Time a chain of ``mx.distributed.all_gather`` calls and report it on rank 0.

    Each evaluation performs ``its_per_eval`` gathers back to back, so the
    figure printed is the evaluation time divided by that count.
    """
    its_per_eval = 100
    start_value = mx.ones((5, 5), mx.int32)

    def chained_all_gather(x):
        # Only element 0 of each gathered result is fed into the next gather.
        for _ in range(its_per_eval):
            x = mx.distributed.all_gather(x)[0]
        return x

    ms = timeit(chained_all_gather, start_value) / its_per_eval
    if rank == 0:
        print(f"All gather: time per iteration {ms:.6f} (ms)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the benchmarks in a fixed order: all-reduce first, then all-gather.
    for benchmark in (all_reduce_benchmark, all_gather_benchmark):
        benchmark()
|
@ -89,6 +89,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
|
@ -134,6 +134,13 @@ CommandEncoder::~CommandEncoder() {
|
||||
enc_->release();
|
||||
}
|
||||
|
||||
void CommandEncoder::set_buffer(
|
||||
const MTL::Buffer* buf,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
enc_->setBuffer(buf, offset, idx);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_input_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
@ -155,6 +162,10 @@ void CommandEncoder::set_output_array(
|
||||
int64_t offset /* = 0 */) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
register_output_array(a);
|
||||
}
|
||||
|
||||
void CommandEncoder::register_output_array(array& a) {
|
||||
all_outputs_.insert(a.buffer().ptr());
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent_) {
|
||||
@ -189,6 +200,10 @@ void CommandEncoder::dispatch_threads(
|
||||
enc_->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
// Encode a memory barrier with buffer scope on the wrapped compute encoder.
void CommandEncoder::barrier() {
  const auto scope = MTL::BarrierScopeBuffers;
  enc_->memoryBarrier(scope);
}
|
||||
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
|
@ -60,9 +60,11 @@ struct CommandEncoder {
|
||||
|
||||
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
||||
void set_output_array(array& a, int idx, int64_t offset = 0);
|
||||
void register_output_array(array& a);
|
||||
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void maybeInsertBarrier();
|
||||
void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0);
|
||||
|
||||
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
|
||||
enc_->setComputePipelineState(kernel);
|
||||
@ -110,6 +112,8 @@ struct CommandEncoder {
|
||||
return all_outputs_;
|
||||
};
|
||||
|
||||
void barrier();
|
||||
|
||||
private:
|
||||
MTL::ComputeCommandEncoder* enc_;
|
||||
bool needs_barrier_{false};
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/fence.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
@ -26,23 +27,28 @@ void AllReduce::eval_gpu(
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
Fence f{stream()};
|
||||
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
}
|
||||
|
||||
auto& out = outputs[0];
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
f.wait_gpu(out);
|
||||
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
signal_and_wait(in.event(), e);
|
||||
auto task = [in = in,
|
||||
out = out,
|
||||
e = std::move(e),
|
||||
f = std::move(f),
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
f.wait();
|
||||
}
|
||||
switch (reduce_type) {
|
||||
case Sum:
|
||||
@ -52,7 +58,7 @@ void AllReduce::eval_gpu(
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
e.signal();
|
||||
f.update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
@ -67,17 +73,20 @@ void AllGather::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
signal_and_wait(in.event(), e);
|
||||
Fence f{stream()};
|
||||
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
}
|
||||
f.wait_gpu(out);
|
||||
|
||||
auto task =
|
||||
[in = in, out = out, e = std::move(e), group = group()]() mutable {
|
||||
[in = in, out = out, f = std::move(f), group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
f.wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
e.signal();
|
||||
f.update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
@ -89,22 +98,28 @@ void Send::eval_gpu(
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Encode a signal event for the input
|
||||
Fence f{stream()};
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
}
|
||||
|
||||
auto& out = outputs[0];
|
||||
move_or_copy(in, out);
|
||||
|
||||
// Schedule an async send on the comm stream
|
||||
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
|
||||
auto task = [in = in,
|
||||
out = out,
|
||||
f = std::move(f),
|
||||
group = group(),
|
||||
dst = dst_]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
f.wait();
|
||||
}
|
||||
distributed::detail::send(group, out, dst);
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
// Encode a signal event for the input
|
||||
if (in.event().valid()) {
|
||||
encode_signal(in.event());
|
||||
}
|
||||
}
|
||||
|
||||
void Recv::eval_gpu(
|
||||
@ -117,16 +132,14 @@ void Recv::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
|
||||
encode_wait(e);
|
||||
Fence f{stream()};
|
||||
f.wait_gpu(out);
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task =
|
||||
[out = out, e = std::move(e), group = group(), src = src_]() mutable {
|
||||
[out = out, f = std::move(f), group = group(), src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
e.signal();
|
||||
f.update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
|
145
mlx/backend/metal/fence.cpp
Normal file
145
mlx/backend/metal/fence.cpp
Normal file
@ -0,0 +1,145 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/fence.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Construct a fence for `stream`.
//
// Two backing implementations exist:
//  * slow path: a Metal shared event, released by the shared_ptr deleter;
//  * fast path: a `uint32_t`-sized buffer from the allocator, used as a
//    counter and initialised to 0.
//
// The fast path is chosen only when the device reports the Metal 3 GPU family,
// the OS availability check (macOS 15 / iOS 18) passes, and
// `env::metal_fast_synch()` returns true. In every other case the slow path is
// used.
Fence::Fence(const Stream& stream) : stream_(stream) {
  auto d = metal::device(stream.device).mtl_device();

  // Start from the slow path so that `use_fast_` always holds a definite
  // value. Previously it was left uninitialized when the device supported
  // Metal 3 but the OS availability check failed.
  use_fast_ = false;
  if (d->supportsFamily(MTL::GPUFamilyMetal3)) {
    if (__builtin_available(macOS 15, iOS 18, *)) {
      use_fast_ = env::metal_fast_synch();
    }
  }

  if (!use_fast_) {
    // Wraps Metal SharedEvent
    auto dtor = [](void* ptr) {
      auto p = metal::new_scoped_memory_pool();
      static_cast<MTL::SharedEvent*>(ptr)->release();
    };
    auto p = metal::new_scoped_memory_pool();
    fence_ = std::shared_ptr<void>(
        metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
  } else {
    // Wraps a small allocator-owned buffer holding the fence counter
    auto dtor = [](void* buf) {
      allocator::free(static_cast<MTL::Buffer*>(buf));
    };
    auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
    fence_ = std::shared_ptr<void>(buf, dtor);
    cpu_value()[0] = 0;
  }
}
|
||||
|
||||
// Make the GPU work queued on this fence's stream wait until the fence counter
// reaches the next value (the counterpart of a later CPU-side `update()`).
//
// Slow path: ends the current encoding and encodes an event wait on the
// command buffer. Fast path: dispatches the single-thread `fence_wait` kernel
// on the fence buffer, after registering `x` as an output of the encoder.
void Fence::wait_gpu(array& x) {
  gpu_count_++;
  auto& dev = metal::device(stream_.device);
  const auto stream_idx = stream_.index;

  if (use_fast_) {
    auto& encoder = dev.get_command_encoder(stream_idx);

    // Register the output to ensure that no kernels which depends on the
    // output starts before this one is done
    encoder.register_output_array(x);

    auto wait_kernel = dev.get_kernel("fence_wait");
    const MTL::Size single_thread = MTL::Size(1, 1, 1);
    encoder.set_compute_pipeline_state(wait_kernel);

    auto counter_buf = static_cast<MTL::Buffer*>(fence_.get());
    encoder.set_buffer(counter_buf, 0);
    encoder.set_bytes(gpu_count_, 1);
    encoder.dispatch_threads(single_thread, single_thread);

    // The empty handler only holds a copy of the shared_ptr; presumably this
    // keeps the fence storage alive until the command buffer completes.
    dev.get_command_buffer(stream_idx)
        ->addCompletedHandler([keep_alive = fence_](MTL::CommandBuffer*) {});
  } else {
    dev.end_encoding(stream_idx);
    auto cmd_buf = dev.get_command_buffer(stream_idx);
    cmd_buf->encodeWait(static_cast<MTL::Event*>(fence_.get()), gpu_count_);
    // Same lifetime-holding handler as on the fast path.
    cmd_buf->addCompletedHandler(
        [keep_alive = fence_](MTL::CommandBuffer*) {});
  }
}
|
||||
|
||||
void Fence::update_gpu(const array& x) {
|
||||
gpu_count_++;
|
||||
auto& d = metal::device(stream_.device);
|
||||
auto idx = stream_.index;
|
||||
|
||||
if (!use_fast_) {
|
||||
d.end_encoding(idx);
|
||||
auto command_buffer = d.get_command_buffer(idx);
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(fence_.get()), gpu_count_);
|
||||
command_buffer->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
return;
|
||||
}
|
||||
|
||||
// Launch input visibility kernel
|
||||
auto& compute_encoder = d.get_command_encoder(idx);
|
||||
auto kernel = d.get_kernel("input_coherent");
|
||||
uint32_t nthreads =
|
||||
(x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1024, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_bytes(nthreads, 1);
|
||||
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
|
||||
|
||||
// Barrier on previous kernel
|
||||
compute_encoder.barrier();
|
||||
|
||||
// Launch value update kernel
|
||||
kernel = d.get_kernel("fence_update");
|
||||
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto buf = static_cast<MTL::Buffer*>(fence_.get());
|
||||
compute_encoder.set_buffer(buf, 0);
|
||||
compute_encoder.set_bytes(gpu_count_, 1);
|
||||
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||
|
||||
d.get_command_buffer(idx)->addCompletedHandler(
|
||||
[fence = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
void Fence::wait() {
|
||||
cpu_count_++;
|
||||
if (!use_fast_) {
|
||||
if (!static_cast<MTL::SharedEvent*>(fence_.get())
|
||||
->waitUntilSignaledValue(cpu_count_, -1)) {
|
||||
throw std::runtime_error("[Event::wait] Timed out");
|
||||
}
|
||||
return;
|
||||
}
|
||||
while (cpu_value()[0] < cpu_count_) {
|
||||
}
|
||||
}
|
||||
|
||||
void Fence::update() {
|
||||
cpu_count_++;
|
||||
|
||||
if (!use_fast_) {
|
||||
static_cast<MTL::SharedEvent*>(fence_.get())->setSignaledValue(cpu_count_);
|
||||
return;
|
||||
}
|
||||
|
||||
cpu_value()[0] = cpu_count_;
|
||||
}
|
||||
|
||||
std::atomic_uint* Fence::cpu_value() {
|
||||
return static_cast<std::atomic_uint*>(
|
||||
static_cast<MTL::Buffer*>(fence_.get())->contents());
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
40
mlx/backend/metal/fence.h
Normal file
40
mlx/backend/metal/fence.h
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
/* A fence to be used for synchronizing work between the CPU and GPU
|
||||
*
|
||||
* Calls to `update_gpu` should be paired with calls to `wait`. This ensures
|
||||
* that the array passed to `update_gpu` is computed and visible to the CPU
|
||||
* after the call to `wait` returns.
|
||||
*
|
||||
* Calls to `update` should be paired with calls to `wait_gpu`. This ensures
|
||||
* that the array passed to `wait_gpu` will not be read by the GPU until the CPU
|
||||
* has called `update`.
|
||||
*
|
||||
* The fence supports slow (default) and fast mode. Fast mode requires setting
|
||||
* the environment variable `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires
|
||||
* Metal 3.2+ (macOS 15+, iOS 18+).
|
||||
*/
|
||||
class Fence {
|
||||
public:
|
||||
Fence(const Stream& stream);
|
||||
|
||||
void update_gpu(const array& x);
|
||||
void wait_gpu(array& x);
|
||||
|
||||
void wait();
|
||||
void update();
|
||||
|
||||
private:
|
||||
Stream stream_;
|
||||
std::shared_ptr<void> fence_;
|
||||
uint32_t cpu_count_{0};
|
||||
uint32_t gpu_count_{0};
|
||||
bool use_fast_;
|
||||
std::atomic_uint* cpu_value();
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
@ -45,6 +45,9 @@ build_kernel(random)
|
||||
build_kernel(rms_norm)
|
||||
build_kernel(rope)
|
||||
build_kernel(scaled_dot_product_attention sdpa_vector.h)
|
||||
if(MLX_METAL_VERSION GREATER_EQUAL 320)
|
||||
build_kernel(fence)
|
||||
endif()
|
||||
|
||||
set(STEEL_HEADERS
|
||||
steel/defines.h
|
||||
|
52
mlx/backend/metal/kernels/fence.metal
Normal file
52
mlx/backend/metal/kernels/fence.metal
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#pragma METAL internals : enable
|
||||
|
||||
// Fallback definition, used only when the toolchain does not already define
// __METAL_MEMORY_SCOPE_SYSTEM__: supply the macro and a matching
// metal::thread_scope_system constant for the kernels in this file.
#ifndef __METAL_MEMORY_SCOPE_SYSTEM__
#define __METAL_MEMORY_SCOPE_SYSTEM__ 3
namespace metal {
constexpr constant thread_scope thread_scope_system =
    static_cast<thread_scope>(__METAL_MEMORY_SCOPE_SYSTEM__);
} // namespace metal
#endif
|
||||
|
||||
#include <metal_atomic>
|
||||
|
||||
// One thread per 32-bit word: read the word and write the same value back
// through a system-coherent pointer, then issue a system-scope fence.
// Threads at or beyond `size` skip the copy but still execute the fence.
[[kernel]] void input_coherent(
    volatile coherent(system) device uint* input [[buffer(0)]],
    const constant uint& size [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  const bool in_bounds = index < size;
  if (in_bounds) {
    const uint word = input[index];
    input[index] = word;
  }
  metal::atomic_thread_fence(
      metal::mem_flags::mem_device,
      metal::memory_order_seq_cst,
      metal::thread_scope_system);
}
|
||||
|
||||
// Single-thread kernel: store `value` into the timestamp word, then issue a
// system-scope, sequentially consistent fence on device memory.
[[kernel]] void fence_update(
    volatile coherent(system) device uint* timestamp [[buffer(0)]],
    constant uint& value [[buffer(1)]]) {
  *timestamp = value;
  metal::atomic_thread_fence(
      metal::mem_flags::mem_device,
      metal::memory_order_seq_cst,
      metal::thread_scope_system);
}
|
||||
|
||||
// Single-thread kernel: spin until the timestamp word is at least `value`.
// A system-scope fence is issued before every read of the timestamp.
[[kernel]] void fence_wait(
    volatile coherent(system) device uint* timestamp [[buffer(0)]],
    constant uint& value [[buffer(1)]]) {
  do {
    metal::atomic_thread_fence(
        metal::mem_flags::mem_device,
        metal::memory_order_seq_cst,
        metal::thread_scope_system);
  } while (timestamp[0] < value);
}
|
@ -2,6 +2,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
@ -69,16 +70,13 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
|
||||
if (signal ||
|
||||
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
|
||||
d.end_encoding(s.index);
|
||||
if (signal) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(arr.event().raw_event().get()),
|
||||
arr.event().value());
|
||||
encode_signal(arr.event());
|
||||
}
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers), event = arr.event()](
|
||||
MTL::CommandBuffer* cbuf) {
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
|
@ -126,6 +126,11 @@ inline int max_ops_per_buffer() {
|
||||
return max_ops_per_buffer_;
|
||||
}
|
||||
|
||||
inline bool metal_fast_synch() {
|
||||
static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0);
|
||||
return metal_fast_synch;
|
||||
}
|
||||
|
||||
} // namespace env
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user