mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Fixes output donation for IO ops on the GPU (#1857)
This commit is contained in:

committed by
GitHub

parent
0a5215693e
commit
0145911bea
@@ -174,6 +174,21 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
finally:
|
||||
mx.distributed.all_sum = original_all_sum
|
||||
|
||||
def test_donation(self):
|
||||
x = mx.random.normal((1024,))
|
||||
mx.eval(x)
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
y = mx.distributed.all_sum(x)
|
||||
mx.eval(y)
|
||||
all_sum_only = mx.metal.get_peak_memory()
|
||||
y = mx.distributed.all_sum(x) * scale
|
||||
mx.eval(y)
|
||||
all_sum_with_binary = mx.metal.get_peak_memory()
|
||||
|
||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -356,6 +356,23 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
aload = mx.load(save_file)["a"]
|
||||
self.assertTrue(mx.array_equal(a, aload))
|
||||
|
||||
def test_load_donation(self):
|
||||
x = mx.random.normal((1024,))
|
||||
mx.eval(x)
|
||||
save_file = os.path.join(self.test_dir, "donation.npy")
|
||||
mx.save(save_file, x)
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
y = mx.load(save_file)
|
||||
mx.eval(y)
|
||||
load_only = mx.metal.get_peak_memory()
|
||||
y = mx.load(save_file) * scale
|
||||
mx.eval(y)
|
||||
load_with_binary = mx.metal.get_peak_memory()
|
||||
|
||||
self.assertEqual(load_only, load_with_binary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user