Fixes output donation for IO ops on the GPU (#1857)

Angelos Katharopoulos 2025-02-12 10:52:30 -08:00
parent 0a5215693e
commit 0145911bea
7 changed files with 92 additions and 16 deletions


@@ -76,6 +76,18 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
   set_data(data, deleter);
 }
 
+array::array(
+    allocator::Buffer data,
+    Shape shape,
+    Dtype dtype,
+    Strides strides,
+    size_t data_size,
+    Flags flags,
+    Deleter deleter)
+    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
+  set_data(data, data_size, std::move(strides), flags, deleter);
+}
+
 void array::detach() {
   for (auto& s : array_desc_->siblings) {
     s.array_desc_->inputs.clear();


@@ -243,6 +243,18 @@ class array {
     bool col_contiguous : 1;
   };
 
+  /** Build an array from all the info held by the array description,
+   * including the buffer, strides, and flags.
+   */
+  explicit array(
+      allocator::Buffer data,
+      Shape shape,
+      Dtype dtype,
+      Strides strides,
+      size_t data_size,
+      Flags flags,
+      Deleter deleter = allocator::free);
+
   /** The array's primitive. */
   Primitive& primitive() const {
     return *(array_desc_->primitive);

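(This constructor, together with its definition in array.cpp above, is what the fix builds on: it reconstructs an array over an existing buffer with explicit strides, data size, and flags, and takes a custom deleter, so a copy can be made that does not own the data. The unsafe_weak_copy helper further down uses it with a no-op deleter.)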

@@ -7,6 +7,7 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/event.h"
 #include "mlx/backend/metal/fence.h"
+#include "mlx/backend/metal/utils.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/scheduler.h"
@@ -43,7 +44,7 @@ void AllReduce::eval_gpu(
   f.wait_gpu(out);
 
   auto task = [in = in,
-               out = out,
+               out = unsafe_weak_copy(out),
                f = std::move(f),
                reduce_type = reduce_type_,
                group = group()]() mutable {
@@ -80,14 +81,16 @@ void AllGather::eval_gpu(
   }
   f.wait_gpu(out);
 
-  auto task =
-      [in = in, out = out, f = std::move(f), group = group()]() mutable {
-        if (in.event().valid()) {
-          f.wait();
-        }
-        distributed::detail::all_gather(group, in, out);
-        f.update();
-      };
+  auto task = [in = in,
+               out = unsafe_weak_copy(out),
+               f = std::move(f),
+               group = group()]() mutable {
+    if (in.event().valid()) {
+      f.wait();
+    }
+    distributed::detail::all_gather(group, in, out);
+    f.update();
+  };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
@@ -110,7 +113,7 @@ void Send::eval_gpu(
 
   // Schedule an async send on the comm stream
   auto task = [in = in,
-               out = out,
+               out = unsafe_weak_copy(out),
                f = std::move(f),
                group = group(),
                dst = dst_]() mutable {
@@ -136,11 +139,13 @@ void Recv::eval_gpu(
   f.wait_gpu(out);
 
   // Schedule an async recv on the comm stream
-  auto task =
-      [out = out, f = std::move(f), group = group(), src = src_]() mutable {
-        distributed::detail::recv(group, out, src);
-        f.update();
-      };
+  auto task = [out = unsafe_weak_copy(out),
+               f = std::move(f),
+               group = group(),
+               src = src_]() mutable {
+    distributed::detail::recv(group, out, src);
+    f.update();
+  };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }

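The one-line change repeated through this file is the core of the fix: each async task used to capture out = out, a strong copy of the output array that pinned its buffer until the task ran, so a later op (such as the multiply in the tests below) could not donate that buffer. Capturing unsafe_weak_copy(out) instead (added in utils.h below) lets the task still use the buffer without pinning it. Here is a standalone sketch of the difference in plain C++, where shared_ptr stands in for the array's buffer ownership; all names are illustrative, nothing here is MLX API:

#include <cassert>
#include <functional>
#include <memory>

int main() {
  auto buf = std::make_shared<int>(0);  // stands in for the output buffer

  // Strong capture: the queued task holds a real reference, so the
  // buffer stays pinned until the task has run.
  std::function<void()> strong_task = [buf] { *buf += 1; };
  assert(buf.use_count() == 2);

  // Weak copy: same pointer, separate control block, no-op deleter.
  // The task can still touch the data but adds no reference to buf.
  std::shared_ptr<int> weak(buf.get(), [](int*) {});
  std::function<void()> weak_task = [weak] { *weak += 1; };
  assert(buf.use_count() == 2);  // unchanged by the weak capture

  strong_task();
  weak_task();
  assert(*buf == 2);
  return 0;
}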

@@ -303,7 +303,7 @@ void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
 
 void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  auto read_task = [out = out,
+  auto read_task = [out = unsafe_weak_copy(out),
                     offset = offset_,
                     reader = reader_,
                     swap_endianness = swap_endianness_]() mutable {


@@ -69,4 +69,19 @@ void concatenate(std::string& acc, T first, Args... args) {
   concatenate(acc, args...);
 }
 
+/**
+ * Get a new array that refers to the same data but holds a non-owning
+ * pointer to it.
+ */
+inline array unsafe_weak_copy(const array& x) {
+  return array(
+      x.buffer(),
+      x.shape(),
+      x.dtype(),
+      x.strides(),
+      x.data_size(),
+      x.flags(),
+      [](auto b) {});
+}
+
 } // namespace mlx::core

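A note on the name: the copy is "unsafe" because the no-op deleter means it never keeps the data alive by itself; the caller must guarantee by other means (in the tasks above, apparently the fence ordering) that the buffer outlives the task. A minimal illustration of the hazard, again in plain C++ with stand-in names rather than the MLX API:

#include <cstdio>
#include <memory>

using Buffer = std::shared_ptr<int>;

// Stand-in for unsafe_weak_copy: same pointer, no-op deleter, non-owning.
Buffer weak_copy(const Buffer& x) {
  return Buffer(x.get(), [](int*) {});
}

int main() {
  auto owner = std::make_shared<int>(7);
  auto weak = weak_copy(owner);
  std::printf("%d\n", *weak);  // fine: owner is still alive
  owner.reset();               // the allocation is released here
  // Dereferencing weak past this point would be a use-after-free;
  // the caller is responsible for the ordering, hence "unsafe".
  return 0;
}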

@@ -174,6 +174,21 @@ class TestDistributed(mlx_tests.MLXTestCase):
         finally:
             mx.distributed.all_sum = original_all_sum
 
+    def test_donation(self):
+        x = mx.random.normal((1024,))
+        mx.eval(x)
+
+        mx.metal.reset_peak_memory()
+        scale = mx.array(2.0)
+        y = mx.distributed.all_sum(x)
+        mx.eval(y)
+        all_sum_only = mx.metal.get_peak_memory()
+        y = mx.distributed.all_sum(x) * scale
+        mx.eval(y)
+        all_sum_with_binary = mx.metal.get_peak_memory()
+        self.assertEqual(all_sum_only, all_sum_with_binary)
+
 if __name__ == "__main__":
     unittest.main()
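Both this test and the load test below use the same check: reset the peak-memory counter, evaluate the op alone, then evaluate the op followed by a multiply. If the multiply can donate the op output's buffer, the second evaluation allocates nothing extra and the two peak readings are equal; before this fix, the strong capture in the async task kept that buffer pinned and the peaks would differ.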


@@ -356,6 +356,23 @@ class TestLoad(mlx_tests.MLXTestCase):
         aload = mx.load(save_file)["a"]
         self.assertTrue(mx.array_equal(a, aload))
 
+    def test_load_donation(self):
+        x = mx.random.normal((1024,))
+        mx.eval(x)
+        save_file = os.path.join(self.test_dir, "donation.npy")
+        mx.save(save_file, x)
+
+        mx.metal.reset_peak_memory()
+        scale = mx.array(2.0)
+        y = mx.load(save_file)
+        mx.eval(y)
+        load_only = mx.metal.get_peak_memory()
+
+        y = mx.load(save_file) * scale
+        mx.eval(y)
+        load_with_binary = mx.metal.get_peak_memory()
+        self.assertEqual(load_only, load_with_binary)
+
 if __name__ == "__main__":
     unittest.main()