mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fix synchronization bug for in stream async works (#1768)
This commit is contained in:
parent
33421c1dd3
commit
f288db8d34
@ -12,11 +12,11 @@
|
|||||||
|
|
||||||
namespace mlx::core::distributed {
|
namespace mlx::core::distributed {
|
||||||
|
|
||||||
void signal_and_wait(const array& in, const array& out) {
|
void signal_and_wait(const Event& e_signal, const Event& e_wait) {
|
||||||
if (in.event().valid()) {
|
if (e_signal.valid()) {
|
||||||
encode_signal(in.event());
|
encode_signal(e_signal);
|
||||||
}
|
}
|
||||||
encode_wait(out.event());
|
encode_wait(e_wait);
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllReduce::eval_gpu(
|
void AllReduce::eval_gpu(
|
||||||
@ -33,8 +33,12 @@ void AllReduce::eval_gpu(
|
|||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto e = Event(stream());
|
||||||
|
e.set_value(1);
|
||||||
|
signal_and_wait(in.event(), e);
|
||||||
auto task = [in = in,
|
auto task = [in = in,
|
||||||
out = out,
|
out = out,
|
||||||
|
e = std::move(e),
|
||||||
reduce_type = reduce_type_,
|
reduce_type = reduce_type_,
|
||||||
group = group()]() mutable {
|
group = group()]() mutable {
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
@ -48,11 +52,9 @@ void AllReduce::eval_gpu(
|
|||||||
default:
|
default:
|
||||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||||
}
|
}
|
||||||
out.event().signal();
|
e.signal();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
|
|
||||||
signal_and_wait(in, out);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllGather::eval_gpu(
|
void AllGather::eval_gpu(
|
||||||
@ -65,15 +67,19 @@ void AllGather::eval_gpu(
|
|||||||
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
auto task = [in = in, out = out, group = group()]() mutable {
|
auto e = Event(stream());
|
||||||
if (in.event().valid()) {
|
e.set_value(1);
|
||||||
in.event().wait();
|
signal_and_wait(in.event(), e);
|
||||||
}
|
|
||||||
distributed::detail::all_gather(group, in, out);
|
auto task =
|
||||||
out.event().signal();
|
[in = in, out = out, e = std::move(e), group = group()]() mutable {
|
||||||
};
|
if (in.event().valid()) {
|
||||||
|
in.event().wait();
|
||||||
|
}
|
||||||
|
distributed::detail::all_gather(group, in, out);
|
||||||
|
e.signal();
|
||||||
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
signal_and_wait(in, out);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Send::eval_gpu(
|
void Send::eval_gpu(
|
||||||
@ -92,12 +98,10 @@ void Send::eval_gpu(
|
|||||||
in.event().wait();
|
in.event().wait();
|
||||||
}
|
}
|
||||||
distributed::detail::send(group, out, dst);
|
distributed::detail::send(group, out, dst);
|
||||||
out.event().signal();
|
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
|
|
||||||
// Encode a signal event for the input but not a wait since we don't need to
|
// Encode a signal event for the input
|
||||||
// wait on the output.
|
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
encode_signal(in.event());
|
encode_signal(in.event());
|
||||||
}
|
}
|
||||||
@ -113,15 +117,18 @@ void Recv::eval_gpu(
|
|||||||
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
// Schedule an async recv on the comm stream
|
auto e = Event(stream());
|
||||||
auto task = [out = out, group = group(), src = src_]() mutable {
|
e.set_value(1);
|
||||||
distributed::detail::recv(group, out, src);
|
|
||||||
out.event().signal();
|
|
||||||
};
|
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
|
||||||
|
|
||||||
// Encode a wait event as there is no input for the recv to encode a signal.
|
encode_wait(e);
|
||||||
encode_wait(out.event());
|
|
||||||
|
// Schedule an async recv on the comm stream
|
||||||
|
auto task =
|
||||||
|
[out = out, e = std::move(e), group = group(), src = src_]() mutable {
|
||||||
|
distributed::detail::recv(group, out, src);
|
||||||
|
e.signal();
|
||||||
|
};
|
||||||
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::distributed
|
} // namespace mlx::core::distributed
|
||||||
|
@ -316,13 +316,17 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
read_task();
|
read_task();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
||||||
auto signal_task = [out = out, fut = std::move(fut)]() {
|
|
||||||
|
auto e = Event(stream());
|
||||||
|
e.set_value(1);
|
||||||
|
encode_wait(e);
|
||||||
|
auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
|
||||||
fut.wait();
|
fut.wait();
|
||||||
out.event().signal();
|
e.signal();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(io_stream(), std::move(signal_task));
|
scheduler::enqueue(io_stream(), std::move(signal_task));
|
||||||
encode_wait(out.event());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
Loading…
Reference in New Issue
Block a user