mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Compare commits
3 Commits
c35f4d089a
...
stop-fence
Author | SHA1 | Date | |
---|---|---|---|
![]() |
c5073fc452 | ||
![]() |
f4a5959055 | ||
![]() |
90801467d8 |
@@ -29,10 +29,10 @@ void AllReduce::eval_gpu(
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
Fence f{stream()};
|
||||
auto f = std::make_shared<Fence>(stream());
|
||||
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
f->update_gpu(in);
|
||||
}
|
||||
|
||||
auto& out = outputs[0];
|
||||
@@ -41,7 +41,7 @@ void AllReduce::eval_gpu(
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
f.wait_gpu(out);
|
||||
f->wait_gpu(out);
|
||||
|
||||
auto task = [in = in,
|
||||
out = unsafe_weak_copy(out),
|
||||
@@ -49,7 +49,7 @@ void AllReduce::eval_gpu(
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
f->wait();
|
||||
}
|
||||
switch (reduce_type) {
|
||||
case Sum:
|
||||
@@ -59,7 +59,7 @@ void AllReduce::eval_gpu(
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
f.update();
|
||||
f->update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
@@ -74,22 +74,22 @@ void AllGather::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
Fence f{stream()};
|
||||
auto f = std::make_shared<Fence>(stream());
|
||||
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
f->update_gpu(in);
|
||||
}
|
||||
f.wait_gpu(out);
|
||||
f->wait_gpu(out);
|
||||
|
||||
auto task = [in = in,
|
||||
out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
f->wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
f.update();
|
||||
f->update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
@@ -103,9 +103,9 @@ void Send::eval_gpu(
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Encode a signal event for the input
|
||||
Fence f{stream()};
|
||||
auto f = std::make_shared<Fence>(stream());
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
f->update_gpu(in);
|
||||
}
|
||||
|
||||
auto& out = outputs[0];
|
||||
@@ -118,7 +118,7 @@ void Send::eval_gpu(
|
||||
group = group(),
|
||||
dst = dst_]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
f->wait();
|
||||
}
|
||||
distributed::detail::send(group, out, dst);
|
||||
};
|
||||
@@ -135,8 +135,8 @@ void Recv::eval_gpu(
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
Fence f{stream()};
|
||||
f.wait_gpu(out);
|
||||
auto f = std::make_shared<Fence>(stream());
|
||||
f->wait_gpu(out);
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task = [out = unsafe_weak_copy(out),
|
||||
@@ -144,7 +144,7 @@ void Recv::eval_gpu(
|
||||
group = group(),
|
||||
src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
f.update();
|
||||
f->update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
|
@@ -34,6 +34,12 @@ Fence::Fence(const Stream& stream) : stream_(stream) {
|
||||
}
|
||||
}
|
||||
|
||||
// Destructor. When a fence object exists (fence_ is non-null) and use_fast_
// is set, store INT_MAX into the first element returned by cpu_value().
// NOTE(review): presumably this sentinel releases anything still waiting on
// the fence value once its owner is destroyed — confirm against wait().
Fence::~Fence() {
|
||||
if (fence_ != nullptr && use_fast_) {
|
||||
cpu_value()[0] = INT_MAX;
|
||||
}
|
||||
}
|
||||
|
||||
void Fence::wait_gpu(array& x) {
|
||||
gpu_count_++;
|
||||
auto& d = metal::device(stream_.device);
|
||||
|
@@ -20,7 +20,11 @@ namespace mlx::core {
|
||||
*/
|
||||
class Fence {
|
||||
public:
|
||||
Fence(const Fence&) = delete;
|
||||
Fence& operator=(const Fence&) = delete;
|
||||
|
||||
Fence(const Stream& stream);
|
||||
~Fence();
|
||||
|
||||
void update_gpu(const array& x);
|
||||
void wait_gpu(array& x);
|
||||
|
@@ -29,7 +29,10 @@ constexpr constant metal::thread_scope thread_scope_system =
|
||||
[[kernel]] void fence_update(
|
||||
volatile coherent(system) device uint* timestamp [[buffer(0)]],
|
||||
constant uint& value [[buffer(1)]]) {
|
||||
timestamp[0] = value;
|
||||
uint tmp = timestamp[0];
|
||||
if (tmp < value) {
|
||||
timestamp[0] = value;
|
||||
}
|
||||
metal::atomic_thread_fence(
|
||||
metal::mem_flags::mem_device,
|
||||
metal::memory_order_seq_cst,
|
||||
|
Reference in New Issue
Block a user