mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-10 21:37:50 +08:00
fix multi output leak (#1548)
This commit is contained in:
@@ -1771,6 +1771,19 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
def fun():
|
||||
a = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||
b, _ = mx.divmod(a, a)
|
||||
return mx.log(b)
|
||||
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_1 = mx.metal.get_peak_memory()
|
||||
fun()
|
||||
mx.synchronize()
|
||||
peak_2 = mx.metal.get_peak_memory()
|
||||
self.assertEqual(peak_1, peak_2)
|
||||
|
||||
def test_add_numpy(self):
|
||||
x = mx.array(1)
|
||||
y = np.array(2, dtype=np.int32)
|
||||
|
Reference in New Issue
Block a user