mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Ensure we only have one copy of the fence
This commit is contained in:
parent
f4a5959055
commit
c5073fc452
@ -29,10 +29,10 @@ void AllReduce::eval_gpu(
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
Fence f{stream()};
|
||||
auto f = std::make_shared<Fence>(stream());
|
||||
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
f->update_gpu(in);
|
||||
}
|
||||
|
||||
auto& out = outputs[0];
|
||||
@ -41,7 +41,7 @@ void AllReduce::eval_gpu(
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
f.wait_gpu(out);
|
||||
f->wait_gpu(out);
|
||||
|
||||
auto task = [in = in,
|
||||
out = unsafe_weak_copy(out),
|
||||
@ -49,7 +49,7 @@ void AllReduce::eval_gpu(
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
f->wait();
|
||||
}
|
||||
switch (reduce_type) {
|
||||
case Sum:
|
||||
@ -59,7 +59,7 @@ void AllReduce::eval_gpu(
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
f.update();
|
||||
f->update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
@ -74,22 +74,22 @@ void AllGather::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
Fence f{stream()};
|
||||
auto f = std::make_shared<Fence>(stream());
|
||||
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
f->update_gpu(in);
|
||||
}
|
||||
f.wait_gpu(out);
|
||||
f->wait_gpu(out);
|
||||
|
||||
auto task = [in = in,
|
||||
out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
f->wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
f.update();
|
||||
f->update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
@ -103,9 +103,9 @@ void Send::eval_gpu(
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Encode a signal event for the input
|
||||
Fence f{stream()};
|
||||
auto f = std::make_shared<Fence>(stream());
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
f->update_gpu(in);
|
||||
}
|
||||
|
||||
auto& out = outputs[0];
|
||||
@ -118,7 +118,7 @@ void Send::eval_gpu(
|
||||
group = group(),
|
||||
dst = dst_]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
f->wait();
|
||||
}
|
||||
distributed::detail::send(group, out, dst);
|
||||
};
|
||||
@ -135,8 +135,8 @@ void Recv::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
Fence f{stream()};
|
||||
f.wait_gpu(out);
|
||||
auto f = std::make_shared<Fence>(stream());
|
||||
f->wait_gpu(out);
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task = [out = unsafe_weak_copy(out),
|
||||
@ -144,7 +144,7 @@ void Recv::eval_gpu(
|
||||
group = group(),
|
||||
src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
f.update();
|
||||
f->update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ Fence::Fence(const Stream& stream) : stream_(stream) {
|
||||
}
|
||||
|
||||
Fence::~Fence() {
|
||||
if (use_fast_) {
|
||||
if (fence_ != nullptr && use_fast_) {
|
||||
cpu_value()[0] = INT_MAX;
|
||||
}
|
||||
}
|
||||
|
@ -20,6 +20,9 @@ namespace mlx::core {
|
||||
*/
|
||||
class Fence {
|
||||
public:
|
||||
Fence(const Fence&) = delete;
|
||||
Fence& operator=(const Fence&) = delete;
|
||||
|
||||
Fence(const Stream& stream);
|
||||
~Fence();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user