fix donation in scan (#1917)

This commit is contained in:
Awni Hannun 2025-03-03 11:30:59 -08:00 committed by GitHub
parent ba12e4999a
commit 6bcd6bcf70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 6 deletions

View File

@ -17,10 +17,10 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = stream(); auto& s = stream();
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
std::vector<array> copies; bool donate = inputs[0].is_donatable();
auto in = inputs[0]; auto in = inputs[0];
if (in.flags().contiguous && in.strides()[axis_] != 0) { if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) { if (donate && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in); out.move_shared_buffer(in);
} else { } else {
out.set_data( out.set_data(
@ -32,8 +32,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
} else { } else {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s); copy_gpu(in, arr_copy, CopyType::General, s);
copies.push_back(arr_copy); in = std::move(arr_copy);
in = arr_copy;
out.move_shared_buffer(in); out.move_shared_buffer(in);
} }
@ -127,8 +126,6 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
MTL::Size group_dims(thread_group_size, 1, 1); MTL::Size group_dims(thread_group_size, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims); compute_encoder.dispatch_threads(grid_dims, group_dims);
} }
d.add_temporaries(std::move(copies), s.index);
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -1894,6 +1894,22 @@ class TestOps(mlx_tests.MLXTestCase):
expected = mx.repeat(expected[:, None], 2, axis=1) expected = mx.repeat(expected[:, None], 2, axis=1)
self.assertTrue(mx.array_equal(expected, out)) self.assertTrue(mx.array_equal(expected, out))
# Test donation
def fn(its):
x = mx.ones((32,))
for _ in range(its):
x = mx.cumsum(x)
return x
mx.synchronize(mx.default_stream(mx.default_device()))
mx.eval(fn(2))
mx.synchronize(mx.default_stream(mx.default_device()))
mem2 = mx.metal.get_peak_memory()
mx.eval(fn(4))
mx.synchronize(mx.default_stream(mx.default_device()))
mem4 = mx.metal.get_peak_memory()
self.assertEqual(mem2, mem4)
def test_squeeze_expand(self): def test_squeeze_expand(self):
a = mx.zeros((2, 1, 2, 1)) a = mx.zeros((2, 1, 2, 1))
self.assertEqual(mx.squeeze(a).shape, (2, 2)) self.assertEqual(mx.squeeze(a).shape, (2, 2))