mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix synchronization bug for in stream async works (#1768)
This commit is contained in:
parent
33421c1dd3
commit
f288db8d34
@ -12,11 +12,11 @@
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
void signal_and_wait(const array& in, const array& out) {
|
||||
if (in.event().valid()) {
|
||||
encode_signal(in.event());
|
||||
void signal_and_wait(const Event& e_signal, const Event& e_wait) {
|
||||
if (e_signal.valid()) {
|
||||
encode_signal(e_signal);
|
||||
}
|
||||
encode_wait(out.event());
|
||||
encode_wait(e_wait);
|
||||
}
|
||||
|
||||
void AllReduce::eval_gpu(
|
||||
@ -33,8 +33,12 @@ void AllReduce::eval_gpu(
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
signal_and_wait(in.event(), e);
|
||||
auto task = [in = in,
|
||||
out = out,
|
||||
e = std::move(e),
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
@ -48,11 +52,9 @@ void AllReduce::eval_gpu(
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
out.event().signal();
|
||||
e.signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
signal_and_wait(in, out);
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
@ -65,15 +67,19 @@ void AllGather::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
auto task = [in = in, out = out, group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
out.event().signal();
|
||||
};
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
signal_and_wait(in.event(), e);
|
||||
|
||||
auto task =
|
||||
[in = in, out = out, e = std::move(e), group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
e.signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
signal_and_wait(in, out);
|
||||
}
|
||||
|
||||
void Send::eval_gpu(
|
||||
@ -92,12 +98,10 @@ void Send::eval_gpu(
|
||||
in.event().wait();
|
||||
}
|
||||
distributed::detail::send(group, out, dst);
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
|
||||
// Encode a signal event for the input but not a wait since we don't need to
|
||||
// wait on the output.
|
||||
// Encode a signal event for the input
|
||||
if (in.event().valid()) {
|
||||
encode_signal(in.event());
|
||||
}
|
||||
@ -113,15 +117,18 @@ void Recv::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task = [out = out, group = group(), src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
out.event().signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
|
||||
// Encode a wait event as there is no input for the recv to encode a signal.
|
||||
encode_wait(out.event());
|
||||
encode_wait(e);
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task =
|
||||
[out = out, e = std::move(e), group = group(), src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
e.signal();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@ -316,13 +316,17 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
read_task();
|
||||
return;
|
||||
}
|
||||
|
||||
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
||||
auto signal_task = [out = out, fut = std::move(fut)]() {
|
||||
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
encode_wait(e);
|
||||
auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
|
||||
fut.wait();
|
||||
out.event().signal();
|
||||
e.signal();
|
||||
};
|
||||
scheduler::enqueue(io_stream(), std::move(signal_task));
|
||||
encode_wait(out.event());
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
Loading…
Reference in New Issue
Block a user