From c5073fc45268fb0ffdddf8c19a348bf1275ffd6e Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 4 Mar 2025 23:37:15 -0800
Subject: [PATCH] Ensure we only have one copy of the fence

---
 mlx/backend/metal/distributed.cpp | 32 ++++++++++++++++----------------
 mlx/backend/metal/fence.cpp       |  2 +-
 mlx/backend/metal/fence.h         |  3 +++
 3 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp
index 9ca727ef4..d52c5a6ef 100644
--- a/mlx/backend/metal/distributed.cpp
+++ b/mlx/backend/metal/distributed.cpp
@@ -29,10 +29,10 @@ void AllReduce::eval_gpu(
 
   auto& in = inputs[0];
 
-  Fence f{stream()};
+  auto f = std::make_shared<Fence>(stream());
 
   if (in.event().valid()) {
-    f.update_gpu(in);
+    f->update_gpu(in);
   }
 
   auto& out = outputs[0];
@@ -41,7 +41,7 @@ void AllReduce::eval_gpu(
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }
-  f.wait_gpu(out);
+  f->wait_gpu(out);
 
   auto task = [in = in,
                out = unsafe_weak_copy(out),
@@ -49,7 +49,7 @@ void AllReduce::eval_gpu(
                reduce_type = reduce_type_,
                group = group()]() mutable {
     if (in.event().valid()) {
-      f.wait();
+      f->wait();
     }
     switch (reduce_type) {
       case Sum:
@@ -59,7 +59,7 @@ void AllReduce::eval_gpu(
       default:
         throw std::runtime_error("Only all reduce sum is supported for now");
     }
-    f.update();
+    f->update();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
@@ -74,22 +74,22 @@ void AllGather::eval_gpu(
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  Fence f{stream()};
+  auto f = std::make_shared<Fence>(stream());
 
   if (in.event().valid()) {
-    f.update_gpu(in);
+    f->update_gpu(in);
   }
-  f.wait_gpu(out);
+  f->wait_gpu(out);
 
   auto task = [in = in,
                out = unsafe_weak_copy(out),
                f = std::move(f),
                group = group()]() mutable {
     if (in.event().valid()) {
-      f.wait();
+      f->wait();
     }
     distributed::detail::all_gather(group, in, out);
-    f.update();
+    f->update();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
@@ -103,9 +103,9 @@ void Send::eval_gpu(
   auto& in = inputs[0];
 
   // Encode a signal event for the input
-  Fence f{stream()};
+  auto f = std::make_shared<Fence>(stream());
   if (in.event().valid()) {
-    f.update_gpu(in);
+    f->update_gpu(in);
   }
 
   auto& out = outputs[0];
@@ -118,7 +118,7 @@ void Send::eval_gpu(
                group = group(),
                dst = dst_]() mutable {
     if (in.event().valid()) {
-      f.wait();
+      f->wait();
     }
     distributed::detail::send(group, out, dst);
   };
@@ -135,8 +135,8 @@ void Recv::eval_gpu(
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  Fence f{stream()};
-  f.wait_gpu(out);
+  auto f = std::make_shared<Fence>(stream());
+  f->wait_gpu(out);
 
   // Schedule an async recv on the comm stream
   auto task = [out = unsafe_weak_copy(out),
@@ -144,7 +144,7 @@ void Recv::eval_gpu(
                f = std::move(f),
                group = group(),
                src = src_]() mutable {
     distributed::detail::recv(group, out, src);
-    f.update();
+    f->update();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp
index 949ce7e23..0fb2e7dc7 100644
--- a/mlx/backend/metal/fence.cpp
+++ b/mlx/backend/metal/fence.cpp
@@ -35,7 +35,7 @@ Fence::Fence(const Stream& stream) : stream_(stream) {
 }
 
 Fence::~Fence() {
-  if (use_fast_) {
+  if (fence_ != nullptr && use_fast_) {
     cpu_value()[0] = INT_MAX;
   }
 }
diff --git a/mlx/backend/metal/fence.h b/mlx/backend/metal/fence.h
index 7d53d469e..5dfba179b 100644
--- a/mlx/backend/metal/fence.h
+++ b/mlx/backend/metal/fence.h
@@ -20,6 +20,9 @@ namespace mlx::core {
  */
 class Fence {
  public:
+  Fence(const Fence&) = delete;
+  Fence& operator=(const Fence&) = delete;
+
   Fence(const Stream& stream);
   ~Fence();
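
For context on the subject line: by allocating the Fence with std::make_shared and deleting its
copy operations, the patch guarantees that the task enqueued on the communication stream shares a
single fence object, however that task is later stored or copied. The standalone sketch below
illustrates the idea only; the Fence struct, the std::function wrapper, and the printed messages
are assumptions made for illustration, not MLX's real fence or scheduler.

    // Sketch (illustrative, not MLX code): copying a task that holds a
    // shared_ptr copies only the pointer, so exactly one Fence is destroyed.
    #include <functional>
    #include <iostream>
    #include <memory>

    struct Fence {
      Fence() { std::cout << "fence created\n"; }
      ~Fence() { std::cout << "fence destroyed\n"; }  // side effect must run once
      Fence(const Fence&) = delete;                   // mirror the patch: no copies
      Fence& operator=(const Fence&) = delete;
    };

    int main() {
      auto f = std::make_shared<Fence>();
      // std::function requires a copyable callable; capturing the shared_ptr by
      // value keeps the lambda copyable without duplicating the Fence itself.
      std::function<void()> task = [f] { /* wait/update on *f */ };
      auto task_copy = task;  // copies the closure, shares the same Fence
      task();
      task_copy();
      return 0;  // "fence destroyed" prints once, when the last owner goes away
    }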