Ensure we only have one copy of the fence

This commit is contained in:
Angelos Katharopoulos 2025-03-04 23:37:15 -08:00
parent f4a5959055
commit c5073fc452
3 changed files with 20 additions and 17 deletions

View File

@@ -29,10 +29,10 @@ void AllReduce::eval_gpu(
auto& in = inputs[0];
Fence f{stream()};
auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) {
f.update_gpu(in);
f->update_gpu(in);
}
auto& out = outputs[0];
@@ -41,7 +41,7 @@ void AllReduce::eval_gpu(
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
f.wait_gpu(out);
f->wait_gpu(out);
auto task = [in = in,
out = unsafe_weak_copy(out),
@@ -49,7 +49,7 @@ void AllReduce::eval_gpu(
reduce_type = reduce_type_,
group = group()]() mutable {
if (in.event().valid()) {
f.wait();
f->wait();
}
switch (reduce_type) {
case Sum:
@@ -59,7 +59,7 @@ void AllReduce::eval_gpu(
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
f.update();
f->update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}
@@ -74,22 +74,22 @@ void AllGather::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()};
auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) {
f.update_gpu(in);
f->update_gpu(in);
}
f.wait_gpu(out);
f->wait_gpu(out);
auto task = [in = in,
out = unsafe_weak_copy(out),
f = std::move(f),
group = group()]() mutable {
if (in.event().valid()) {
f.wait();
f->wait();
}
distributed::detail::all_gather(group, in, out);
f.update();
f->update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}
@@ -103,9 +103,9 @@ void Send::eval_gpu(
auto& in = inputs[0];
// Encode a signal event for the input
Fence f{stream()};
auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) {
f.update_gpu(in);
f->update_gpu(in);
}
auto& out = outputs[0];
@@ -118,7 +118,7 @@ void Send::eval_gpu(
group = group(),
dst = dst_]() mutable {
if (in.event().valid()) {
f.wait();
f->wait();
}
distributed::detail::send(group, out, dst);
};
@@ -135,8 +135,8 @@ void Recv::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()};
f.wait_gpu(out);
auto f = std::make_shared<Fence>(stream());
f->wait_gpu(out);
// Schedule an async recv on the comm stream
auto task = [out = unsafe_weak_copy(out),
@@ -144,7 +144,7 @@ void Recv::eval_gpu(
group = group(),
src = src_]() mutable {
distributed::detail::recv(group, out, src);
f.update();
f->update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}

View File

@@ -35,7 +35,7 @@ Fence::Fence(const Stream& stream) : stream_(stream) {
}
Fence::~Fence() {
if (use_fast_) {
if (fence_ != nullptr && use_fast_) {
cpu_value()[0] = INT_MAX;
}
}

View File

@@ -20,6 +20,9 @@ namespace mlx::core {
*/
class Fence {
public:
Fence(const Fence&) = delete;
Fence& operator=(const Fence&) = delete;
Fence(const Stream& stream);
~Fence();