mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 09:14:34 +08:00
Compare commits
3 Commits
v0.24.1
...
stop-fence
Author | SHA1 | Date | |
---|---|---|---|
![]() |
c5073fc452 | ||
![]() |
f4a5959055 | ||
![]() |
90801467d8 |
@@ -29,10 +29,10 @@ void AllReduce::eval_gpu(
|
|||||||
|
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
Fence f{stream()};
|
auto f = std::make_shared<Fence>(stream());
|
||||||
|
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
f.update_gpu(in);
|
f->update_gpu(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& out = outputs[0];
|
auto& out = outputs[0];
|
||||||
@@ -41,7 +41,7 @@ void AllReduce::eval_gpu(
|
|||||||
} else {
|
} else {
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
}
|
}
|
||||||
f.wait_gpu(out);
|
f->wait_gpu(out);
|
||||||
|
|
||||||
auto task = [in = in,
|
auto task = [in = in,
|
||||||
out = unsafe_weak_copy(out),
|
out = unsafe_weak_copy(out),
|
||||||
@@ -49,7 +49,7 @@ void AllReduce::eval_gpu(
|
|||||||
reduce_type = reduce_type_,
|
reduce_type = reduce_type_,
|
||||||
group = group()]() mutable {
|
group = group()]() mutable {
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
f.wait();
|
f->wait();
|
||||||
}
|
}
|
||||||
switch (reduce_type) {
|
switch (reduce_type) {
|
||||||
case Sum:
|
case Sum:
|
||||||
@@ -59,7 +59,7 @@ void AllReduce::eval_gpu(
|
|||||||
default:
|
default:
|
||||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||||
}
|
}
|
||||||
f.update();
|
f->update();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
}
|
}
|
||||||
@@ -74,22 +74,22 @@ void AllGather::eval_gpu(
|
|||||||
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
Fence f{stream()};
|
auto f = std::make_shared<Fence>(stream());
|
||||||
|
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
f.update_gpu(in);
|
f->update_gpu(in);
|
||||||
}
|
}
|
||||||
f.wait_gpu(out);
|
f->wait_gpu(out);
|
||||||
|
|
||||||
auto task = [in = in,
|
auto task = [in = in,
|
||||||
out = unsafe_weak_copy(out),
|
out = unsafe_weak_copy(out),
|
||||||
f = std::move(f),
|
f = std::move(f),
|
||||||
group = group()]() mutable {
|
group = group()]() mutable {
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
f.wait();
|
f->wait();
|
||||||
}
|
}
|
||||||
distributed::detail::all_gather(group, in, out);
|
distributed::detail::all_gather(group, in, out);
|
||||||
f.update();
|
f->update();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
}
|
}
|
||||||
@@ -103,9 +103,9 @@ void Send::eval_gpu(
|
|||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
// Encode a signal event for the input
|
// Encode a signal event for the input
|
||||||
Fence f{stream()};
|
auto f = std::make_shared<Fence>(stream());
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
f.update_gpu(in);
|
f->update_gpu(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto& out = outputs[0];
|
auto& out = outputs[0];
|
||||||
@@ -118,7 +118,7 @@ void Send::eval_gpu(
|
|||||||
group = group(),
|
group = group(),
|
||||||
dst = dst_]() mutable {
|
dst = dst_]() mutable {
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
f.wait();
|
f->wait();
|
||||||
}
|
}
|
||||||
distributed::detail::send(group, out, dst);
|
distributed::detail::send(group, out, dst);
|
||||||
};
|
};
|
||||||
@@ -135,8 +135,8 @@ void Recv::eval_gpu(
|
|||||||
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
Fence f{stream()};
|
auto f = std::make_shared<Fence>(stream());
|
||||||
f.wait_gpu(out);
|
f->wait_gpu(out);
|
||||||
|
|
||||||
// Schedule an async recv on the comm stream
|
// Schedule an async recv on the comm stream
|
||||||
auto task = [out = unsafe_weak_copy(out),
|
auto task = [out = unsafe_weak_copy(out),
|
||||||
@@ -144,7 +144,7 @@ void Recv::eval_gpu(
|
|||||||
group = group(),
|
group = group(),
|
||||||
src = src_]() mutable {
|
src = src_]() mutable {
|
||||||
distributed::detail::recv(group, out, src);
|
distributed::detail::recv(group, out, src);
|
||||||
f.update();
|
f->update();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
}
|
}
|
||||||
|
@@ -34,6 +34,12 @@ Fence::Fence(const Stream& stream) : stream_(stream) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Fence::~Fence() {
|
||||||
|
if (fence_ != nullptr && use_fast_) {
|
||||||
|
cpu_value()[0] = INT_MAX;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Fence::wait_gpu(array& x) {
|
void Fence::wait_gpu(array& x) {
|
||||||
gpu_count_++;
|
gpu_count_++;
|
||||||
auto& d = metal::device(stream_.device);
|
auto& d = metal::device(stream_.device);
|
||||||
|
@@ -20,7 +20,11 @@ namespace mlx::core {
|
|||||||
*/
|
*/
|
||||||
class Fence {
|
class Fence {
|
||||||
public:
|
public:
|
||||||
|
Fence(const Fence&) = delete;
|
||||||
|
Fence& operator=(const Fence&) = delete;
|
||||||
|
|
||||||
Fence(const Stream& stream);
|
Fence(const Stream& stream);
|
||||||
|
~Fence();
|
||||||
|
|
||||||
void update_gpu(const array& x);
|
void update_gpu(const array& x);
|
||||||
void wait_gpu(array& x);
|
void wait_gpu(array& x);
|
||||||
|
@@ -29,7 +29,10 @@ constexpr constant metal::thread_scope thread_scope_system =
|
|||||||
[[kernel]] void fence_update(
|
[[kernel]] void fence_update(
|
||||||
volatile coherent(system) device uint* timestamp [[buffer(0)]],
|
volatile coherent(system) device uint* timestamp [[buffer(0)]],
|
||||||
constant uint& value [[buffer(1)]]) {
|
constant uint& value [[buffer(1)]]) {
|
||||||
timestamp[0] = value;
|
uint tmp = timestamp[0];
|
||||||
|
if (tmp < value) {
|
||||||
|
timestamp[0] = value;
|
||||||
|
}
|
||||||
metal::atomic_thread_fence(
|
metal::atomic_thread_fence(
|
||||||
metal::mem_flags::mem_device,
|
metal::mem_flags::mem_device,
|
||||||
metal::memory_order_seq_cst,
|
metal::memory_order_seq_cst,
|
||||||
|
Reference in New Issue
Block a user