From 9bd03dd9b43d1c842722103b1c3da3f83c21d5b8 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 18 Nov 2024 08:35:41 -0800
Subject: [PATCH] More buffer donation with no-ops (#1591)

* more donation

* fix test

* fix build
---
 mlx/backend/common/common.cpp    | 16 +++++++-------
 mlx/backend/common/slicing.cpp   |  2 +-
 mlx/backend/common/utils.cpp     | 22 +++++++++++++++++++
 mlx/backend/common/utils.h       |  9 ++++++++
 mlx/backend/metal/primitives.cpp |  7 +++---
 mlx/backend/metal/unary.cpp      |  2 +-
 python/tests/test_eval.py        | 37 ++++++++++++++++++++++++++++++++
 7 files changed, 82 insertions(+), 13 deletions(-)

diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp
index a01d94eb3..fba9dc15b 100644
--- a/mlx/backend/common/common.cpp
+++ b/mlx/backend/common/common.cpp
@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   // rely on data_size anyway.
   size_t data_size = out.size();
 
-  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
+  return move_or_copy(in, out, strides_, flags, data_size, offset_);
 }
 
 void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -58,12 +58,12 @@
   if (out.size() > in.size()) {
     flags.row_contiguous = flags.col_contiguous = false;
   }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
+  move_or_copy(in, out, strides, flags, in.data_size());
 }
 
 void Copy::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
 }
 
 void CustomTransforms::eval(
@@ -72,7 +72,7 @@
   assert(inputs.size() > outputs.size());
   for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
        i++, j++) {
-    outputs[i].copy_shared_buffer(inputs[j]);
+    move_or_copy(inputs[j], outputs[i]);
   }
 }
 
@@ -81,7 +81,7 @@ void Depends::eval(
     std::vector<array>& outputs) {
   assert(inputs.size() > outputs.size());
   for (int i = 0; i < outputs.size(); i++) {
-    outputs[i].copy_shared_buffer(inputs[i]);
+    move_or_copy(inputs[i], outputs[i]);
   }
 }
 
@@ -194,7 +194,7 @@ void Reshape::shared_buffer_reshape(
     auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
     flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
   }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
 }
 
 void Split::eval(
@@ -263,7 +263,7 @@ std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
 
 void StopGradient::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
 }
 
 void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -297,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
       b_stride *= out.shape(ri);
     }
   }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp
index 015cfad5f..343f0ff57 100644
--- a/mlx/backend/common/slicing.cpp
+++ b/mlx/backend/common/slicing.cpp
@@ -34,7 +34,7 @@ void shared_buffer_slice(
   flags.col_contiguous = is_col_contiguous;
   flags.contiguous = (no_bsx_size == data_size);
 
-  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
+  move_or_copy(in, out, out_strides, flags, data_size, data_offset);
 }
 
 } // namespace mlx::core
diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp
index 05ee47566..97fdfe968 100644
--- a/mlx/backend/common/utils.cpp
+++ b/mlx/backend/common/utils.cpp
@@ -4,6 +4,28 @@
 
 namespace mlx::core {
 
+void move_or_copy(const array& in, array& out) {
+  if (in.is_donatable()) {
+    out.move_shared_buffer(in);
+  } else {
+    out.copy_shared_buffer(in);
+  }
+}
+
+void move_or_copy(
+    const array& in,
+    array& out,
+    const std::vector<size_t>& strides,
+    array::Flags flags,
+    size_t data_size,
+    size_t offset /* = 0 */) {
+  if (in.is_donatable()) {
+    out.move_shared_buffer(in, strides, flags, data_size, offset);
+  } else {
+    out.copy_shared_buffer(in, strides, flags, data_size, offset);
+  }
+}
+
 template <typename stride_t>
 std::tuple<std::vector<int>, std::vector<std::vector<stride_t>>>
 collapse_contiguous_dims_impl(
diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h
index 7ce38b908..3d466ed51 100644
--- a/mlx/backend/common/utils.h
+++ b/mlx/backend/common/utils.h
@@ -178,4 +178,13 @@ inline bool is_donatable(const array& in, const array& out) {
       in.buffer_size() <= out.nbytes() + donation_extra;
 }
 
+void move_or_copy(const array& in, array& out);
+void move_or_copy(
+    const array& in,
+    array& out,
+    const std::vector<size_t>& strides,
+    array::Flags flags,
+    size_t data_size,
+    size_t offset = 0);
+
 } // namespace mlx::core
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 732b40edf..d09176f20 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -5,6 +5,7 @@
 #include
 
 #include "mlx/backend/common/load.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
@@ -343,7 +344,7 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& upd = inputs[1];
 
   if (upd.size() == 0) {
-    out.copy_shared_buffer(in);
+    move_or_copy(in, out);
     return;
   }
 
@@ -420,8 +421,8 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
       strides[i] *= ibytes;
       strides[i] /= obytes;
     }
-    out.copy_shared_buffer(
-        in, strides, in.flags(), in.data_size() * ibytes / obytes);
+    move_or_copy(
+        in, out, strides, in.flags(), in.data_size() * ibytes / obytes);
   } else {
     auto tmp = array(in.shape(), in.dtype(), nullptr, {});
     tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 8d23d4192..e9baad065 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
     unary_op_gpu(inputs, out, get_primitive_string(this));
   } else {
     // No-op integer types
-    out.copy_shared_buffer(in);
+    move_or_copy(in, out);
   }
 }
 
diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py
index 5856d84fd..37e31f80b 100644
--- a/python/tests/test_eval.py
+++ b/python/tests/test_eval.py
@@ -137,6 +137,43 @@ class TestEval(mlx_tests.MLXTestCase):
             mx.async_eval(x)
             mx.eval(a + b)
 
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    def test_donation_for_noops(self):
+        def fun(x):
+            s = x.shape
+            for _ in range(10):
+                x = mx.abs(x)
+                x = mx.reshape(x, (-1,))
+                x = x.T.T
+                x = mx.stop_gradient(x)
+            x = mx.abs(x)
+            return x
+
+        x = mx.zeros((4096, 4096))
+        mx.eval(x)
+        pre = mx.metal.get_peak_memory()
+        out = fun(x)
+        del x
+        mx.eval(out)
+        post = mx.metal.get_peak_memory()
+        self.assertEqual(pre, post)
+
+        def fun(x):
+            for _ in range(10):
+                x = mx.abs(x)
+                x = x[:-1]
+            x = mx.abs(x)
+            return x
+
+        x = mx.zeros((4096 * 4096,))
+        mx.eval(x)
+        pre = mx.metal.get_peak_memory()
+        out = fun(x)
+        del x
+        mx.eval(out)
+        post = mx.metal.get_peak_memory()
+        self.assertEqual(pre, post)
+
 
 if __name__ == "__main__":
     unittest.main()
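
Not part of the patch: a minimal standalone sketch of the behavior the new test exercises, assuming an MLX build that includes this change. A chain of element-wise ops and no-op views (reshape, double transpose, stop_gradient) applied to a donatable input should leave peak Metal memory unchanged, since each of these primitives now moves the shared buffer when the input is donatable instead of copying it. All APIs used below appear in the patch's own test.

    import mlx.core as mx

    # Materialize a large input so its buffer is the only big allocation.
    x = mx.zeros((4096, 4096))
    mx.eval(x)
    pre = mx.metal.get_peak_memory()

    # Element-wise ops interleaved with no-op views. With donation, these
    # reuse x's buffer rather than allocating a second one of the same size.
    y = mx.abs(x)
    y = mx.reshape(y, (-1,))
    y = y.T.T
    y = mx.stop_gradient(y)
    y = mx.abs(y)

    del x  # drop the last external reference so the buffer can be donated
    mx.eval(y)
    post = mx.metal.get_peak_memory()
    assert pre == post  # peak memory unchanged: the buffer was donated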