diff --git a/mlx/array.cpp b/mlx/array.cpp
index b06de8fa3..9db264317 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -76,6 +76,18 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
   set_data(data, deleter);
 }
 
+array::array(
+    allocator::Buffer data,
+    Shape shape,
+    Dtype dtype,
+    Strides strides,
+    size_t data_size,
+    Flags flags,
+    Deleter deleter)
+    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
+  set_data(data, data_size, std::move(strides), flags, deleter);
+}
+
 void array::detach() {
   for (auto& s : array_desc_->siblings) {
     s.array_desc_->inputs.clear();
diff --git a/mlx/array.h b/mlx/array.h
index 2c1b35cbc..e5446088b 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -243,6 +243,18 @@ class array {
     bool col_contiguous : 1;
   };
 
+  /** Build an array from all the info held by the array description,
+   * including the buffer, strides, and flags.
+   */
+  explicit array(
+      allocator::Buffer data,
+      Shape shape,
+      Dtype dtype,
+      Strides strides,
+      size_t data_size,
+      Flags flags,
+      Deleter deleter = allocator::free);
+
   /** The array's primitive. */
   Primitive& primitive() const {
     return *(array_desc_->primitive);
diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp
index 6f36e9484..9ca727ef4 100644
--- a/mlx/backend/metal/distributed.cpp
+++ b/mlx/backend/metal/distributed.cpp
@@ -7,6 +7,7 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/event.h"
 #include "mlx/backend/metal/fence.h"
+#include "mlx/backend/metal/utils.h"
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 #include "mlx/scheduler.h"
@@ -43,7 +44,7 @@ void AllReduce::eval_gpu(
   f.wait_gpu(out);
 
   auto task = [in = in,
-               out = out,
+               out = unsafe_weak_copy(out),
                f = std::move(f),
                reduce_type = reduce_type_,
                group = group()]() mutable {
@@ -80,14 +81,16 @@ void AllGather::eval_gpu(
   }
   f.wait_gpu(out);
 
-  auto task =
-      [in = in, out = out, f = std::move(f), group = group()]() mutable {
-        if (in.event().valid()) {
-          f.wait();
-        }
-        distributed::detail::all_gather(group, in, out);
-        f.update();
-      };
+  auto task = [in = in,
+               out = unsafe_weak_copy(out),
+               f = std::move(f),
+               group = group()]() mutable {
+    if (in.event().valid()) {
+      f.wait();
+    }
+    distributed::detail::all_gather(group, in, out);
+    f.update();
+  };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
 
@@ -110,7 +113,7 @@ void Send::eval_gpu(
 
   // Schedule an async send on the comm stream
   auto task = [in = in,
-               out = out,
+               out = unsafe_weak_copy(out),
                f = std::move(f),
                group = group(),
                dst = dst_]() mutable {
@@ -136,11 +139,13 @@ void Recv::eval_gpu(
   f.wait_gpu(out);
 
   // Schedule an async recv on the comm stream
-  auto task =
-      [out = out, f = std::move(f), group = group(), src = src_]() mutable {
-        distributed::detail::recv(group, out, src);
-        f.update();
-      };
+  auto task = [out = unsafe_weak_copy(out),
+               f = std::move(f),
+               group = group(),
+               src = src_]() mutable {
+    distributed::detail::recv(group, out, src);
+    f.update();
+  };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 }
 
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index d07a66ee0..df8638012 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -303,7 +303,7 @@ void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
 
 void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  auto read_task = [out = out,
+  auto read_task = [out = unsafe_weak_copy(out),
                     offset = offset_,
                     reader = reader_,
                     swap_endianness = swap_endianness_]() mutable {
diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h
index cc56bab32..09415e999 100644
--- a/mlx/backend/metal/utils.h
+++ b/mlx/backend/metal/utils.h
@@ -69,4 +69,19 @@ void concatenate(std::string& acc, T first, Args... args) {
   concatenate(acc, args...);
 }
 
+/**
+ * Get a new array that refers to the same data but holds a non-owning
+ * pointer to it.
+ */
+inline array unsafe_weak_copy(const array& x) {
+  return array(
+      x.buffer(),
+      x.shape(),
+      x.dtype(),
+      x.strides(),
+      x.data_size(),
+      x.flags(),
+      [](auto b) {});
+}
+
 } // namespace mlx::core
diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py
index 79c58d8b6..d63033429 100644
--- a/python/tests/mpi_test_distributed.py
+++ b/python/tests/mpi_test_distributed.py
@@ -174,6 +174,21 @@ class TestDistributed(mlx_tests.MLXTestCase):
         finally:
             mx.distributed.all_sum = original_all_sum
 
+    def test_donation(self):
+        x = mx.random.normal((1024,))
+        mx.eval(x)
+
+        mx.metal.reset_peak_memory()
+        scale = mx.array(2.0)
+        y = mx.distributed.all_sum(x)
+        mx.eval(y)
+        all_sum_only = mx.metal.get_peak_memory()
+        y = mx.distributed.all_sum(x) * scale
+        mx.eval(y)
+        all_sum_with_binary = mx.metal.get_peak_memory()
+
+        self.assertEqual(all_sum_only, all_sum_with_binary)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_load.py b/python/tests/test_load.py
index 36425e418..0b42e4a6b 100644
--- a/python/tests/test_load.py
+++ b/python/tests/test_load.py
@@ -356,6 +356,23 @@ class TestLoad(mlx_tests.MLXTestCase):
         aload = mx.load(save_file)["a"]
         self.assertTrue(mx.array_equal(a, aload))
 
+    def test_load_donation(self):
+        x = mx.random.normal((1024,))
+        mx.eval(x)
+        save_file = os.path.join(self.test_dir, "donation.npy")
+        mx.save(save_file, x)
+
+        mx.metal.reset_peak_memory()
+        scale = mx.array(2.0)
+        y = mx.load(save_file)
+        mx.eval(y)
+        load_only = mx.metal.get_peak_memory()
+        y = mx.load(save_file) * scale
+        mx.eval(y)
+        load_with_binary = mx.metal.get_peak_memory()
+
+        self.assertEqual(load_only, load_with_binary)
+
 
 if __name__ == "__main__":
     unittest.main()
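
Editor's note, not part of the patch: each `eval_gpu` above previously captured `out` by value, so the pending closure held an extra reference to the output array (and thus its buffer) until the task ran; a downstream op (e.g. the `* scale` in the new tests) then saw the array as non-donatable and allocated a fresh buffer. `unsafe_weak_copy` instead hands the closure an aliasing array whose deleter is a no-op, so the graph stays the sole owner and the equal peak-memory assertions in `test_donation` and `test_load_donation` hold. Below is a minimal sketch of the pattern using only helpers from this diff; `schedule_with_donation` and `do_async_work` are hypothetical stand-ins, and the copy is "unsafe" precisely because the closure no longer keeps the buffer alive, so the surrounding fence/event synchronization must guarantee the real output outlives the task.

#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/scheduler.h"

namespace mlx::core {

// Hypothetical stand-in for the body of an async task (send, recv, read, ...).
void do_async_work(array& out);

// Sketch of the capture pattern this diff applies to
// AllReduce/AllGather/Send/Recv/Load.
void schedule_with_donation(array& out, Stream s) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  // Capturing [out = out] would keep a strong reference inside the
  // closure and block later donation of the allocation; the weak copy
  // aliases the same buffer without owning it.
  auto task = [out = unsafe_weak_copy(out)]() mutable {
    do_async_work(out);
  };
  scheduler::enqueue(s, std::move(task));
}

} // namespace mlx::core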