Fix synchronization bug for in-stream async work (#1768)

This commit is contained in:
Awni Hannun 2025-01-15 06:07:34 -08:00 committed by GitHub
parent 33421c1dd3
commit f288db8d34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 29 deletions

View File

@ -12,11 +12,11 @@
namespace mlx::core::distributed {
void signal_and_wait(const array& in, const array& out) {
if (in.event().valid()) {
encode_signal(in.event());
void signal_and_wait(const Event& e_signal, const Event& e_wait) {
if (e_signal.valid()) {
encode_signal(e_signal);
}
encode_wait(out.event());
encode_wait(e_wait);
}
void AllReduce::eval_gpu(
@ -33,8 +33,12 @@ void AllReduce::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
auto e = Event(stream());
e.set_value(1);
signal_and_wait(in.event(), e);
auto task = [in = in,
out = out,
e = std::move(e),
reduce_type = reduce_type_,
group = group()]() mutable {
if (in.event().valid()) {
@ -48,11 +52,9 @@ void AllReduce::eval_gpu(
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
out.event().signal();
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out);
}
void AllGather::eval_gpu(
@ -65,15 +67,19 @@ void AllGather::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto task = [in = in, out = out, group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::all_gather(group, in, out);
out.event().signal();
};
auto e = Event(stream());
e.set_value(1);
signal_and_wait(in.event(), e);
auto task =
[in = in, out = out, e = std::move(e), group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::all_gather(group, in, out);
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out);
}
void Send::eval_gpu(
@ -92,12 +98,10 @@ void Send::eval_gpu(
in.event().wait();
}
distributed::detail::send(group, out, dst);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
// Encode a signal event for the input but not a wait since we don't need to
// wait on the output.
// Encode a signal event for the input
if (in.event().valid()) {
encode_signal(in.event());
}
@ -113,15 +117,18 @@ void Recv::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Schedule an async recv on the comm stream
auto task = [out = out, group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
auto e = Event(stream());
e.set_value(1);
// Encode a wait event as there is no input for the recv to encode a signal.
encode_wait(out.event());
encode_wait(e);
// Schedule an async recv on the comm stream
auto task =
[out = out, e = std::move(e), group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}
} // namespace mlx::core::distributed

View File

@ -316,13 +316,17 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
read_task();
return;
}
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
auto signal_task = [out = out, fut = std::move(fut)]() {
auto e = Event(stream());
e.set_value(1);
encode_wait(e);
auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
fut.wait();
out.event().signal();
e.signal();
};
scheduler::enqueue(io_stream(), std::move(signal_task));
encode_wait(out.event());
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {