Compare commits

...

3 Commits

Author SHA1 Message Date
Angelos Katharopoulos
c5073fc452 Ensure we only have one copy of the fence 2025-03-04 23:37:15 -08:00
Angelos Katharopoulos
f4a5959055 Fix update gpu 2025-03-04 21:26:31 -08:00
Angelos Katharopoulos
90801467d8 Stop the fence in the destructor 2025-03-04 20:45:17 -08:00
4 changed files with 30 additions and 17 deletions

View File

@@ -29,10 +29,10 @@ void AllReduce::eval_gpu(
auto& in = inputs[0];
Fence f{stream()};
auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) {
f.update_gpu(in);
f->update_gpu(in);
}
auto& out = outputs[0];
@@ -41,7 +41,7 @@ void AllReduce::eval_gpu(
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
f.wait_gpu(out);
f->wait_gpu(out);
auto task = [in = in,
out = unsafe_weak_copy(out),
@@ -49,7 +49,7 @@ void AllReduce::eval_gpu(
reduce_type = reduce_type_,
group = group()]() mutable {
if (in.event().valid()) {
f.wait();
f->wait();
}
switch (reduce_type) {
case Sum:
@@ -59,7 +59,7 @@ void AllReduce::eval_gpu(
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
f.update();
f->update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}
@@ -74,22 +74,22 @@ void AllGather::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()};
auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) {
f.update_gpu(in);
f->update_gpu(in);
}
f.wait_gpu(out);
f->wait_gpu(out);
auto task = [in = in,
out = unsafe_weak_copy(out),
f = std::move(f),
group = group()]() mutable {
if (in.event().valid()) {
f.wait();
f->wait();
}
distributed::detail::all_gather(group, in, out);
f.update();
f->update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}
@@ -103,9 +103,9 @@ void Send::eval_gpu(
auto& in = inputs[0];
// Encode a signal event for the input
Fence f{stream()};
auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) {
f.update_gpu(in);
f->update_gpu(in);
}
auto& out = outputs[0];
@@ -118,7 +118,7 @@ void Send::eval_gpu(
group = group(),
dst = dst_]() mutable {
if (in.event().valid()) {
f.wait();
f->wait();
}
distributed::detail::send(group, out, dst);
};
@@ -135,8 +135,8 @@ void Recv::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()};
f.wait_gpu(out);
auto f = std::make_shared<Fence>(stream());
f->wait_gpu(out);
// Schedule an async recv on the comm stream
auto task = [out = unsafe_weak_copy(out),
@@ -144,7 +144,7 @@ void Recv::eval_gpu(
group = group(),
src = src_]() mutable {
distributed::detail::recv(group, out, src);
f.update();
f->update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}

View File

@@ -34,6 +34,12 @@ Fence::Fence(const Stream& stream) : stream_(stream) {
}
}
// Destructor: when a fence object exists and the fast path is enabled, store
// INT_MAX into the first element returned by cpu_value() before the Fence
// goes away.
// NOTE(review): presumably this "stops" the fence so that anything still
// comparing against the value sees a maximal one — confirm against the
// wait/update implementations, which are not visible here.
Fence::~Fence() {
  // Nothing to do without a fence object or when the fast path is not in use.
  if (fence_ == nullptr || !use_fast_) {
    return;
  }
  cpu_value()[0] = INT_MAX;
}
void Fence::wait_gpu(array& x) {
gpu_count_++;
auto& d = metal::device(stream_.device);

View File

@@ -20,7 +20,11 @@ namespace mlx::core {
*/
class Fence {
public:
Fence(const Fence&) = delete;
Fence& operator=(const Fence&) = delete;
Fence(const Stream& stream);
~Fence();
void update_gpu(const array& x);
void wait_gpu(array& x);

View File

@@ -29,7 +29,10 @@ constexpr constant metal::thread_scope thread_scope_system =
[[kernel]] void fence_update(
volatile coherent(system) device uint* timestamp [[buffer(0)]],
constant uint& value [[buffer(1)]]) {
timestamp[0] = value;
uint tmp = timestamp[0];
if (tmp < value) {
timestamp[0] = value;
}
metal::atomic_thread_fence(
metal::mem_flags::mem_device,
metal::memory_order_seq_cst,