mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 16:51:24 +08:00
@@ -2,6 +2,7 @@
|
||||
|
||||
import operator
|
||||
import pickle
|
||||
import resource
|
||||
import sys
|
||||
import unittest
|
||||
import weakref
|
||||
@@ -1928,6 +1929,25 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.sin(x)
|
||||
mx.eval(x)
|
||||
|
||||
def test_siblings_without_eval(self):
|
||||
def get_mem():
|
||||
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
|
||||
key = mx.array([1, 2])
|
||||
|
||||
def t():
|
||||
a, b = mx.split(key, 2)
|
||||
a = mx.reshape(a, [])
|
||||
b = mx.reshape(b, [])
|
||||
return b
|
||||
|
||||
t()
|
||||
expected = get_mem()
|
||||
for _ in range(100):
|
||||
t()
|
||||
used = get_mem()
|
||||
self.assertEqual(expected, used)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user