diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index 46c2a9bea1..b31fdf4f9b 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -17,10 +17,10 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& d = metal::device(s.device); - std::vector copies; + bool donate = inputs[0].is_donatable(); auto in = inputs[0]; if (in.flags().contiguous && in.strides()[axis_] != 0) { - if (in.is_donatable() && in.itemsize() == out.itemsize()) { + if (donate && in.itemsize() == out.itemsize()) { out.move_shared_buffer(in); } else { out.set_data( @@ -32,8 +32,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { } else { array arr_copy(in.shape(), in.dtype(), nullptr, {}); copy_gpu(in, arr_copy, CopyType::General, s); - copies.push_back(arr_copy); - in = arr_copy; + in = std::move(arr_copy); out.move_shared_buffer(in); } @@ -127,8 +126,6 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { MTL::Size group_dims(thread_group_size, 1, 1); compute_encoder.dispatch_threads(grid_dims, group_dims); } - - d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 515836a5ed..8e1cd8efde 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1894,6 +1894,22 @@ class TestOps(mlx_tests.MLXTestCase): expected = mx.repeat(expected[:, None], 2, axis=1) self.assertTrue(mx.array_equal(expected, out)) + # Test donation + def fn(its): + x = mx.ones((32,)) + for _ in range(its): + x = mx.cumsum(x) + return x + + mx.synchronize(mx.default_stream(mx.default_device())) + mx.eval(fn(2)) + mx.synchronize(mx.default_stream(mx.default_device())) + mem2 = mx.metal.get_peak_memory() + mx.eval(fn(4)) + mx.synchronize(mx.default_stream(mx.default_device())) + mem4 = mx.metal.get_peak_memory() + self.assertEqual(mem2, mem4) + def test_squeeze_expand(self): a = mx.zeros((2, 1, 2, 1)) self.assertEqual(mx.squeeze(a).shape, (2, 2))