Ensure we only have one copy of the fence

This commit is contained in:
Angelos Katharopoulos 2025-03-04 23:37:15 -08:00
parent f4a5959055
commit c5073fc452
3 changed files with 20 additions and 17 deletions

View File

@@ -29,10 +29,10 @@ void AllReduce::eval_gpu(
auto& in = inputs[0]; auto& in = inputs[0];
Fence f{stream()}; auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) { if (in.event().valid()) {
f.update_gpu(in); f->update_gpu(in);
} }
auto& out = outputs[0]; auto& out = outputs[0];
@@ -41,7 +41,7 @@ void AllReduce::eval_gpu(
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
} }
f.wait_gpu(out); f->wait_gpu(out);
auto task = [in = in, auto task = [in = in,
out = unsafe_weak_copy(out), out = unsafe_weak_copy(out),
@@ -49,7 +49,7 @@ void AllReduce::eval_gpu(
reduce_type = reduce_type_, reduce_type = reduce_type_,
group = group()]() mutable { group = group()]() mutable {
if (in.event().valid()) { if (in.event().valid()) {
f.wait(); f->wait();
} }
switch (reduce_type) { switch (reduce_type) {
case Sum: case Sum:
@@ -59,7 +59,7 @@ void AllReduce::eval_gpu(
default: default:
throw std::runtime_error("Only all reduce sum is supported for now"); throw std::runtime_error("Only all reduce sum is supported for now");
} }
f.update(); f->update();
}; };
scheduler::enqueue(detail::communication_stream(), std::move(task)); scheduler::enqueue(detail::communication_stream(), std::move(task));
} }
@@ -74,22 +74,22 @@ void AllGather::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()}; auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) { if (in.event().valid()) {
f.update_gpu(in); f->update_gpu(in);
} }
f.wait_gpu(out); f->wait_gpu(out);
auto task = [in = in, auto task = [in = in,
out = unsafe_weak_copy(out), out = unsafe_weak_copy(out),
f = std::move(f), f = std::move(f),
group = group()]() mutable { group = group()]() mutable {
if (in.event().valid()) { if (in.event().valid()) {
f.wait(); f->wait();
} }
distributed::detail::all_gather(group, in, out); distributed::detail::all_gather(group, in, out);
f.update(); f->update();
}; };
scheduler::enqueue(detail::communication_stream(), std::move(task)); scheduler::enqueue(detail::communication_stream(), std::move(task));
} }
@@ -103,9 +103,9 @@ void Send::eval_gpu(
auto& in = inputs[0]; auto& in = inputs[0];
// Encode a signal event for the input // Encode a signal event for the input
Fence f{stream()}; auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) { if (in.event().valid()) {
f.update_gpu(in); f->update_gpu(in);
} }
auto& out = outputs[0]; auto& out = outputs[0];
@@ -118,7 +118,7 @@ void Send::eval_gpu(
group = group(), group = group(),
dst = dst_]() mutable { dst = dst_]() mutable {
if (in.event().valid()) { if (in.event().valid()) {
f.wait(); f->wait();
} }
distributed::detail::send(group, out, dst); distributed::detail::send(group, out, dst);
}; };
@@ -135,8 +135,8 @@ void Recv::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()}; auto f = std::make_shared<Fence>(stream());
f.wait_gpu(out); f->wait_gpu(out);
// Schedule an async recv on the comm stream // Schedule an async recv on the comm stream
auto task = [out = unsafe_weak_copy(out), auto task = [out = unsafe_weak_copy(out),
@@ -144,7 +144,7 @@ void Recv::eval_gpu(
group = group(), group = group(),
src = src_]() mutable { src = src_]() mutable {
distributed::detail::recv(group, out, src); distributed::detail::recv(group, out, src);
f.update(); f->update();
}; };
scheduler::enqueue(detail::communication_stream(), std::move(task)); scheduler::enqueue(detail::communication_stream(), std::move(task));
} }

View File

@@ -35,7 +35,7 @@ Fence::Fence(const Stream& stream) : stream_(stream) {
} }
Fence::~Fence() { Fence::~Fence() {
if (use_fast_) { if (fence_ != nullptr && use_fast_) {
cpu_value()[0] = INT_MAX; cpu_value()[0] = INT_MAX;
} }
} }

View File

@@ -20,6 +20,9 @@ namespace mlx::core {
*/ */
class Fence { class Fence {
public: public:
Fence(const Fence&) = delete;
Fence& operator=(const Fence&) = delete;
Fence(const Stream& stream); Fence(const Stream& stream);
~Fence(); ~Fence();