Fixes output donation for IO ops on the GPU (#1857)

This commit is contained in:
Angelos Katharopoulos
2025-02-12 10:52:30 -08:00
committed by GitHub
parent 0a5215693e
commit 0145911bea
7 changed files with 92 additions and 16 deletions

View File

@@ -7,6 +7,7 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/fence.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/scheduler.h"
@@ -43,7 +44,7 @@ void AllReduce::eval_gpu(
f.wait_gpu(out);
auto task = [in = in,
out = out,
out = unsafe_weak_copy(out),
f = std::move(f),
reduce_type = reduce_type_,
group = group()]() mutable {
@@ -80,14 +81,16 @@ void AllGather::eval_gpu(
}
f.wait_gpu(out);
auto task =
[in = in, out = out, f = std::move(f), group = group()]() mutable {
if (in.event().valid()) {
f.wait();
}
distributed::detail::all_gather(group, in, out);
f.update();
};
auto task = [in = in,
out = unsafe_weak_copy(out),
f = std::move(f),
group = group()]() mutable {
if (in.event().valid()) {
f.wait();
}
distributed::detail::all_gather(group, in, out);
f.update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}
@@ -110,7 +113,7 @@ void Send::eval_gpu(
// Schedule an async send on the comm stream
auto task = [in = in,
out = out,
out = unsafe_weak_copy(out),
f = std::move(f),
group = group(),
dst = dst_]() mutable {
@@ -136,11 +139,13 @@ void Recv::eval_gpu(
f.wait_gpu(out);
// Schedule an async recv on the comm stream
auto task =
[out = out, f = std::move(f), group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
f.update();
};
auto task = [out = unsafe_weak_copy(out),
f = std::move(f),
group = group(),
src = src_]() mutable {
distributed::detail::recv(group, out, src);
f.update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}

View File

@@ -303,7 +303,7 @@ void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = out,
auto read_task = [out = unsafe_weak_copy(out),
offset = offset_,
reader = reader_,
swap_endianness = swap_endianness_]() mutable {

View File

@@ -69,4 +69,19 @@ void concatenate(std::string& acc, T first, Args... args) {
concatenate(acc, args...);
}
/**
* Get a new array that refers to the same data but has a non-owning pointer to
* them.
*/
inline array unsafe_weak_copy(const array& x) {
return array(
x.buffer(),
x.shape(),
x.dtype(),
x.strides(),
x.data_size(),
x.flags(),
[](auto b) {});
}
} // namespace mlx::core