Fixes output donation for IO ops on the GPU (#1857)

This commit is contained in:
Angelos Katharopoulos
2025-02-12 10:52:30 -08:00
committed by GitHub
parent 0a5215693e
commit 0145911bea
7 changed files with 92 additions and 16 deletions

View File

@@ -174,6 +174,21 @@ class TestDistributed(mlx_tests.MLXTestCase):
finally:
mx.distributed.all_sum = original_all_sum
def test_donation(self):
    """Check that all_sum's output buffer is donated to a following binary op.

    If donation works, running ``all_sum(x) * scale`` should not allocate
    any more peak memory than running ``all_sum(x)`` alone.
    """
    # Materialize the input first so its allocation is excluded from the
    # peak-memory measurement below.
    data = mx.random.normal((1024,))
    mx.eval(data)

    mx.metal.reset_peak_memory()
    factor = mx.array(2.0)

    # Baseline: peak memory with the collective alone.
    result = mx.distributed.all_sum(data)
    mx.eval(result)
    peak_alone = mx.metal.get_peak_memory()

    # Same collective followed by a multiply; with donation the multiply
    # reuses the all_sum output buffer, so the peak should not grow.
    result = mx.distributed.all_sum(data) * factor
    mx.eval(result)
    peak_with_binary = mx.metal.get_peak_memory()

    self.assertEqual(peak_alone, peak_with_binary)
# Run this test module directly with the standard unittest runner.
if __name__ == "__main__":
    unittest.main()

View File

@@ -356,6 +356,23 @@ class TestLoad(mlx_tests.MLXTestCase):
aload = mx.load(save_file)["a"]
self.assertTrue(mx.array_equal(a, aload))
def test_load_donation(self):
    """Check that ``mx.load``'s output buffer is donated to a following binary op.

    If donation works, ``mx.load(path) * scale`` should not allocate any
    more peak memory than ``mx.load(path)`` alone.
    """
    # Write the fixture array to disk before measuring memory, so the
    # save itself is excluded from the peak-memory comparison.
    data = mx.random.normal((1024,))
    mx.eval(data)
    save_file = os.path.join(self.test_dir, "donation.npy")
    mx.save(save_file, data)

    mx.metal.reset_peak_memory()
    factor = mx.array(2.0)

    # Baseline: peak memory for the load alone.
    loaded = mx.load(save_file)
    mx.eval(loaded)
    peak_alone = mx.metal.get_peak_memory()

    # Load followed by a multiply; with donation the multiply reuses the
    # loaded buffer, so the peak should not grow.
    loaded = mx.load(save_file) * factor
    mx.eval(loaded)
    peak_with_binary = mx.metal.get_peak_memory()

    self.assertEqual(peak_alone, peak_with_binary)
# Run this test module directly with the standard unittest runner.
if __name__ == "__main__":
    unittest.main()