Fix synchronization bug for in stream async works (#1768)

Awni Hannun 2025-01-15 06:07:34 -08:00 committed by GitHub
parent 33421c1dd3
commit f288db8d34
2 changed files with 40 additions and 29 deletions

View File

@@ -12,11 +12,11 @@
 namespace mlx::core::distributed {

-void signal_and_wait(const array& in, const array& out) {
-  if (in.event().valid()) {
-    encode_signal(in.event());
+void signal_and_wait(const Event& e_signal, const Event& e_wait) {
+  if (e_signal.valid()) {
+    encode_signal(e_signal);
   }
-  encode_wait(out.event());
+  encode_wait(e_wait);
 }

 void AllReduce::eval_gpu(
@@ -33,8 +33,12 @@ void AllReduce::eval_gpu(
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }

+  auto e = Event(stream());
+  e.set_value(1);
+  signal_and_wait(in.event(), e);
   auto task = [in = in,
                out = out,
+               e = std::move(e),
                reduce_type = reduce_type_,
                group = group()]() mutable {
     if (in.event().valid()) {
@@ -48,11 +52,9 @@ void AllReduce::eval_gpu(
       default:
         throw std::runtime_error("Only all reduce sum is supported for now");
     }
-    out.event().signal();
+    e.signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
-
-  signal_and_wait(in, out);
 }

 void AllGather::eval_gpu(
@@ -65,15 +67,19 @@ void AllGather::eval_gpu(
   out.set_data(allocator::malloc_or_wait(out.nbytes()));

-  auto task = [in = in, out = out, group = group()]() mutable {
+  auto e = Event(stream());
+  e.set_value(1);
+  signal_and_wait(in.event(), e);
+  auto task =
+      [in = in, out = out, e = std::move(e), group = group()]() mutable {
     if (in.event().valid()) {
       in.event().wait();
     }
     distributed::detail::all_gather(group, in, out);
-    out.event().signal();
+    e.signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
-
-  signal_and_wait(in, out);
 }

 void Send::eval_gpu(
@@ -92,12 +98,10 @@ void Send::eval_gpu(
       in.event().wait();
     }
     distributed::detail::send(group, out, dst);
-    out.event().signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));

-  // Encode a signal event for the input but not a wait since we don't need to
-  // wait on the output.
+  // Encode a signal event for the input
   if (in.event().valid()) {
     encode_signal(in.event());
   }
@@ -113,15 +117,18 @@ void Recv::eval_gpu(
   out.set_data(allocator::malloc_or_wait(out.nbytes()));

+  auto e = Event(stream());
+  e.set_value(1);
+  encode_wait(e);
   // Schedule an async recv on the comm stream
-  auto task = [out = out, group = group(), src = src_]() mutable {
+  auto task =
+      [out = out, e = std::move(e), group = group(), src = src_]() mutable {
     distributed::detail::recv(group, out, src);
-    out.event().signal();
+    e.signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
-
-  // Encode a wait event as there is no input for the recv to encode a signal.
-  encode_wait(out.event());
 }

 } // namespace mlx::core::distributed
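Taken together, the hunks above all make the same change: instead of signaling the output array's event from inside the communication-stream task and encoding the signal/wait after scheduler::enqueue, each primitive now creates its own Event on stream(), calls set_value(1), and encodes the signal/wait before the task is enqueued; the task then signals that event once the collective (or recv) has run. Below is a minimal, self-contained model of that ordering; ToyEvent, the bare std::thread, and the comm_stream name are stand-ins for illustration only, not MLX's real Event, stream, or scheduler API.

#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

// ToyEvent stands in for MLX's Event: signal() marks completion, wait()
// blocks until someone has signaled.
struct ToyEvent {
  std::mutex m;
  std::condition_variable cv;
  bool fired = false;
  void signal() {
    {
      std::lock_guard<std::mutex> lk(m);
      fired = true;
    }
    cv.notify_all();
  }
  void wait() {
    std::unique_lock<std::mutex> lk(m);
    cv.wait(lk, [&] { return fired; });
  }
};

int main() {
  ToyEvent e; // plays the role of Event(stream()) with set_value(1)

  // The wait on e is set up here, before the communication-stream task even
  // exists (the role of signal_and_wait(in.event(), e) / encode_wait(e) in
  // the diff above).
  std::thread comm_stream([&e] {
    // ... the all_reduce / all_gather / recv body would run here ...
    e.signal(); // replaces the old out.event().signal()
  });

  e.wait(); // the consumer of "out" blocks here until the task signals
  std::printf("collective finished; output is safe to read\n");
  comm_stream.join();
  return 0;
}

The toy only demonstrates the ordering the diff establishes: the wait on e is registered before the asynchronous work is scheduled, and the worker signals that same event when it finishes.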

View File

@@ -316,13 +316,17 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
     read_task();
     return;
   }

   auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
-  auto signal_task = [out = out, fut = std::move(fut)]() {
+  auto e = Event(stream());
+  e.set_value(1);
+  encode_wait(e);
+  auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
     fut.wait();
-    out.event().signal();
+    e.signal();
   };
   scheduler::enqueue(io_stream(), std::move(signal_task));
-
-  encode_wait(out.event());
 }

 void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
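Load::eval_gpu gets the same treatment for I/O: the read still runs on io::thread_pool(), but the event is created up front, encode_wait(e) runs before the signal task is scheduled, and that task merely bridges the read's future to the event. A rough, self-contained model of that bridge built on standard-library futures follows; the std::promise, std::async, and the fake 1 KiB read are stand-ins for illustration, not MLX's Event or io::thread_pool() API.

#include <cstdio>
#include <future>
#include <thread>
#include <vector>

int main() {
  // "e" stands in for the Event created on the op's stream; its wait is
  // registered (encode_wait) before any asynchronous work is scheduled.
  std::promise<void> e;
  std::shared_future<void> e_wait = e.get_future().share();

  // read_task: the actual I/O, running on a worker thread
  // (io::thread_pool() in the diff above).
  std::shared_future<std::vector<char>> fut =
      std::async(std::launch::async, [] {
        return std::vector<char>(1024, 0); // pretend we read 1 KiB from disk
      }).share();

  // signal_task: waits for the read to finish, then signals the event.
  std::thread signal_task([fut, e = std::move(e)]() mutable {
    fut.wait();
    e.set_value(); // e.signal() in the diff above
  });

  e_wait.wait(); // the consumer blocks on the event, never on the raw future
  std::printf("load finished: %zu bytes\n", fut.get().size());
  signal_task.join();
  return 0;
}

As in the diff, the consumer waits on the event rather than on the read future itself; the dedicated signal task is what turns "the read completed" into "the event is signaled".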