Compare commits

...

3 Commits

Author SHA1 Message Date
Angelos Katharopoulos
c5073fc452 Ensure we only have one copy of the fence 2025-03-04 23:37:15 -08:00
Angelos Katharopoulos
f4a5959055 Fix update gpu 2025-03-04 21:26:31 -08:00
Angelos Katharopoulos
90801467d8 Stop the fence in the destructor 2025-03-04 20:45:17 -08:00
4 changed files with 30 additions and 17 deletions

View File

@@ -29,10 +29,10 @@ void AllReduce::eval_gpu(
auto& in = inputs[0]; auto& in = inputs[0];
Fence f{stream()}; auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) { if (in.event().valid()) {
f.update_gpu(in); f->update_gpu(in);
} }
auto& out = outputs[0]; auto& out = outputs[0];
@@ -41,7 +41,7 @@ void AllReduce::eval_gpu(
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
} }
f.wait_gpu(out); f->wait_gpu(out);
auto task = [in = in, auto task = [in = in,
out = unsafe_weak_copy(out), out = unsafe_weak_copy(out),
@@ -49,7 +49,7 @@ void AllReduce::eval_gpu(
reduce_type = reduce_type_, reduce_type = reduce_type_,
group = group()]() mutable { group = group()]() mutable {
if (in.event().valid()) { if (in.event().valid()) {
f.wait(); f->wait();
} }
switch (reduce_type) { switch (reduce_type) {
case Sum: case Sum:
@@ -59,7 +59,7 @@ void AllReduce::eval_gpu(
default: default:
throw std::runtime_error("Only all reduce sum is supported for now"); throw std::runtime_error("Only all reduce sum is supported for now");
} }
f.update(); f->update();
}; };
scheduler::enqueue(detail::communication_stream(), std::move(task)); scheduler::enqueue(detail::communication_stream(), std::move(task));
} }
@@ -74,22 +74,22 @@ void AllGather::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()}; auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) { if (in.event().valid()) {
f.update_gpu(in); f->update_gpu(in);
} }
f.wait_gpu(out); f->wait_gpu(out);
auto task = [in = in, auto task = [in = in,
out = unsafe_weak_copy(out), out = unsafe_weak_copy(out),
f = std::move(f), f = std::move(f),
group = group()]() mutable { group = group()]() mutable {
if (in.event().valid()) { if (in.event().valid()) {
f.wait(); f->wait();
} }
distributed::detail::all_gather(group, in, out); distributed::detail::all_gather(group, in, out);
f.update(); f->update();
}; };
scheduler::enqueue(detail::communication_stream(), std::move(task)); scheduler::enqueue(detail::communication_stream(), std::move(task));
} }
@@ -103,9 +103,9 @@ void Send::eval_gpu(
auto& in = inputs[0]; auto& in = inputs[0];
// Encode a signal event for the input // Encode a signal event for the input
Fence f{stream()}; auto f = std::make_shared<Fence>(stream());
if (in.event().valid()) { if (in.event().valid()) {
f.update_gpu(in); f->update_gpu(in);
} }
auto& out = outputs[0]; auto& out = outputs[0];
@@ -118,7 +118,7 @@ void Send::eval_gpu(
group = group(), group = group(),
dst = dst_]() mutable { dst = dst_]() mutable {
if (in.event().valid()) { if (in.event().valid()) {
f.wait(); f->wait();
} }
distributed::detail::send(group, out, dst); distributed::detail::send(group, out, dst);
}; };
@@ -135,8 +135,8 @@ void Recv::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()}; auto f = std::make_shared<Fence>(stream());
f.wait_gpu(out); f->wait_gpu(out);
// Schedule an async recv on the comm stream // Schedule an async recv on the comm stream
auto task = [out = unsafe_weak_copy(out), auto task = [out = unsafe_weak_copy(out),
@@ -144,7 +144,7 @@ void Recv::eval_gpu(
group = group(), group = group(),
src = src_]() mutable { src = src_]() mutable {
distributed::detail::recv(group, out, src); distributed::detail::recv(group, out, src);
f.update(); f->update();
}; };
scheduler::enqueue(detail::communication_stream(), std::move(task)); scheduler::enqueue(detail::communication_stream(), std::move(task));
} }

View File

@@ -34,6 +34,12 @@ Fence::Fence(const Stream& stream) : stream_(stream) {
} }
} }
Fence::~Fence() {
if (fence_ != nullptr && use_fast_) {
cpu_value()[0] = INT_MAX;
}
}
void Fence::wait_gpu(array& x) { void Fence::wait_gpu(array& x) {
gpu_count_++; gpu_count_++;
auto& d = metal::device(stream_.device); auto& d = metal::device(stream_.device);

View File

@@ -20,7 +20,11 @@ namespace mlx::core {
*/ */
class Fence { class Fence {
public: public:
Fence(const Fence&) = delete;
Fence& operator=(const Fence&) = delete;
Fence(const Stream& stream); Fence(const Stream& stream);
~Fence();
void update_gpu(const array& x); void update_gpu(const array& x);
void wait_gpu(array& x); void wait_gpu(array& x);

View File

@@ -29,7 +29,10 @@ constexpr constant metal::thread_scope thread_scope_system =
[[kernel]] void fence_update( [[kernel]] void fence_update(
volatile coherent(system) device uint* timestamp [[buffer(0)]], volatile coherent(system) device uint* timestamp [[buffer(0)]],
constant uint& value [[buffer(1)]]) { constant uint& value [[buffer(1)]]) {
timestamp[0] = value; uint tmp = timestamp[0];
if (tmp < value) {
timestamp[0] = value;
}
metal::atomic_thread_fence( metal::atomic_thread_fence(
metal::mem_flags::mem_device, metal::mem_flags::mem_device,
metal::memory_order_seq_cst, metal::memory_order_seq_cst,