Ensure we only have one copy of the fence
parent f4a5959055
commit c5073fc452
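The diff below swaps the stack-allocated `Fence f{stream()};` in each distributed primitive for `auto f = std::make_shared<Fence>(stream());` and routes every call through `->`, so the GPU encoding path and the host-side task share a single fence object instead of holding separate copies. As a minimal sketch of the underlying C++ behaviour (the `Resource` type and the printed counts are illustrative, not the MLX `Fence`): capturing an object by value in a lambda copies it, and copying the closure copies it again, whereas capturing a `shared_ptr` keeps one instance no matter how many times the closure is copied.

#include <functional>
#include <iostream>
#include <memory>

// Hypothetical stand-in for a fence-like resource; counts live instances.
struct Resource {
  static inline int live = 0;
  Resource() { ++live; }
  Resource(const Resource&) { ++live; }
  ~Resource() { --live; }
};

int main() {
  {
    Resource r;
    std::function<void()> task = [r]() {};  // capture by value: one extra copy
    auto task2 = task;                      // copying the closure copies r again
    std::cout << "by value: " << Resource::live << " live instances\n";  // prints 3
  }
  {
    auto r = std::make_shared<Resource>();
    std::function<void()> task = [r]() {};  // capture the pointer, not the object
    auto task2 = task;                      // both closures share one Resource
    std::cout << "shared_ptr: " << Resource::live << " live instance\n";  // prints 1
  }
  return 0;
}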
@@ -29,10 +29,10 @@ void AllReduce::eval_gpu(
 
   auto& in = inputs[0];
 
-  Fence f{stream()};
+  auto f = std::make_shared<Fence>(stream());
 
   if (in.event().valid()) {
-    f.update_gpu(in);
+    f->update_gpu(in);
   }
 
   auto& out = outputs[0];
@@ -41,7 +41,7 @@ void AllReduce::eval_gpu(
   } else {
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }
-  f.wait_gpu(out);
+  f->wait_gpu(out);
 
   auto task = [in = in,
                out = unsafe_weak_copy(out),
@@ -49,7 +49,7 @@ void AllReduce::eval_gpu(
                reduce_type = reduce_type_,
                group = group()]() mutable {
     if (in.event().valid()) {
-      f.wait();
+      f->wait();
     }
     switch (reduce_type) {
       case Sum:
@@ -59,7 +59,7 @@ void AllReduce::eval_gpu(
       default:
         throw std::runtime_error("Only all reduce sum is supported for now");
     }
-    f.update();
+    f->update();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
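A second effect of the `shared_ptr` is lifetime: the task enqueued on the communication stream runs after `eval_gpu` has returned, and the captured pointer keeps the fence alive until that task finishes. A rough sketch of the pattern, assuming a toy `ToyFence` and a plain `std::thread` in place of MLX's scheduler:

#include <chrono>
#include <iostream>
#include <memory>
#include <thread>

// Toy stand-in for the real fence; not the MLX Fence.
struct ToyFence {
  void update() { std::cout << "fence updated\n"; }
  ~ToyFence() { std::cout << "fence destroyed\n"; }
};

void enqueue_work() {
  auto f = std::make_shared<ToyFence>();  // single shared instance
  std::thread([f]() {                     // the async task co-owns the fence
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    f->update();                          // still valid although enqueue_work has returned
  }).detach();
}  // the local f is gone here, but the task's copy keeps the fence alive

int main() {
  enqueue_work();
  std::this_thread::sleep_for(std::chrono::milliseconds(200));
  return 0;
}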
@@ -74,22 +74,22 @@ void AllGather::eval_gpu(
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  Fence f{stream()};
+  auto f = std::make_shared<Fence>(stream());
 
   if (in.event().valid()) {
-    f.update_gpu(in);
+    f->update_gpu(in);
   }
-  f.wait_gpu(out);
+  f->wait_gpu(out);
 
   auto task = [in = in,
                out = unsafe_weak_copy(out),
                f = std::move(f),
                group = group()]() mutable {
     if (in.event().valid()) {
-      f.wait();
+      f->wait();
     }
     distributed::detail::all_gather(group, in, out);
-    f.update();
+    f->update();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
@@ -103,9 +103,9 @@ void Send::eval_gpu(
   auto& in = inputs[0];
 
   // Encode a signal event for the input
-  Fence f{stream()};
+  auto f = std::make_shared<Fence>(stream());
   if (in.event().valid()) {
-    f.update_gpu(in);
+    f->update_gpu(in);
   }
 
   auto& out = outputs[0];
@@ -118,7 +118,7 @@ void Send::eval_gpu(
                group = group(),
                dst = dst_]() mutable {
     if (in.event().valid()) {
-      f.wait();
+      f->wait();
     }
     distributed::detail::send(group, out, dst);
   };
@@ -135,8 +135,8 @@ void Recv::eval_gpu(
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  Fence f{stream()};
-  f.wait_gpu(out);
+  auto f = std::make_shared<Fence>(stream());
+  f->wait_gpu(out);
 
   // Schedule an async recv on the comm stream
   auto task = [out = unsafe_weak_copy(out),
@@ -144,7 +144,7 @@ void Recv::eval_gpu(
                group = group(),
                src = src_]() mutable {
     distributed::detail::recv(group, out, src);
-    f.update();
+    f->update();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
@@ -35,7 +35,7 @@ Fence::Fence(const Stream& stream) : stream_(stream) {
 }
 
 Fence::~Fence() {
-  if (use_fast_) {
+  if (fence_ != nullptr && use_fast_) {
     cpu_value()[0] = INT_MAX;
   }
 }
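The destructor now checks `fence_ != nullptr` before poking the fast-path value, presumably so a fence whose internal buffer was never set up can still be destroyed safely. A minimal sketch of the same guard with hypothetical members (`fence_`, `use_fast_`, and `cpu_value()` stand in for whatever the real class holds; the real Fence's internals differ):

#include <climits>
#include <cstdint>
#include <cstdlib>

// Sketch of the destructor guard pattern with hypothetical members.
class GuardedFence {
 public:
  ~GuardedFence() {
    if (fence_ != nullptr && use_fast_) {
      cpu_value()[0] = INT_MAX;  // release any waiter only if the buffer exists
    }
    std::free(fence_);  // free(nullptr) is a no-op
  }

 private:
  std::uint32_t* cpu_value() { return static_cast<std::uint32_t*>(fence_); }
  void* fence_ = nullptr;  // may stay null if the fast path was never set up
  bool use_fast_ = false;
};

int main() {
  GuardedFence f;  // destructor runs safely with fence_ == nullptr
  return 0;
}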
@@ -20,6 +20,9 @@ namespace mlx::core {
  */
 class Fence {
  public:
+  Fence(const Fence&) = delete;
+  Fence& operator=(const Fence&) = delete;
+
   Fence(const Stream& stream);
   ~Fence();
 
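Deleting the copy constructor and copy assignment turns accidental duplication into a compile-time error, which is what enforces the "one copy of the fence" invariant now that the object is passed around through `std::shared_ptr`. A small sketch with a hypothetical `NonCopyable` class:

#include <memory>

// Minimal sketch mirroring the deleted members above: a non-copyable handle.
class NonCopyable {
 public:
  NonCopyable() = default;
  NonCopyable(const NonCopyable&) = delete;
  NonCopyable& operator=(const NonCopyable&) = delete;
};

int main() {
  NonCopyable a;
  (void)a;
  // NonCopyable b = a;             // does not compile: copy constructor is deleted
  auto p = std::make_shared<NonCopyable>();
  auto task = [p]() { (void)p; };   // copying the shared_ptr is still fine
  task();
  return 0;
}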