mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-02 17:28:12 +08:00
move memory APIs into top level mlx.core (#1982)
This commit is contained in:
@@ -118,9 +118,9 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
if mx.metal.is_available():
|
||||
peak_mem = mx.metal.get_peak_memory()
|
||||
peak_mem = mx.get_peak_memory()
|
||||
out = mx.vjp(fn, (x,), (y,))
|
||||
self.assertEqual(peak_mem, mx.metal.get_peak_memory())
|
||||
self.assertEqual(peak_mem, mx.get_peak_memory())
|
||||
|
||||
def test_async_eval_with_multiple_streams(self):
|
||||
x = mx.array([1.0])
|
||||
@@ -151,11 +151,11 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
x = mx.zeros((4096, 4096))
|
||||
mx.eval(x)
|
||||
pre = mx.metal.get_peak_memory()
|
||||
pre = mx.get_peak_memory()
|
||||
out = fun(x)
|
||||
del x
|
||||
mx.eval(out)
|
||||
post = mx.metal.get_peak_memory()
|
||||
post = mx.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
def fun(x):
|
||||
@@ -167,11 +167,11 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
x = mx.zeros((4096 * 4096,))
|
||||
mx.eval(x)
|
||||
pre = mx.metal.get_peak_memory()
|
||||
pre = mx.get_peak_memory()
|
||||
out = fun(x)
|
||||
del x
|
||||
mx.eval(out)
|
||||
post = mx.metal.get_peak_memory()
|
||||
post = mx.get_peak_memory()
|
||||
self.assertEqual(pre, post)
|
||||
|
||||
@unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
|
||||
@@ -187,7 +187,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
|
||||
s1 = mx.default_stream(mx.gpu)
|
||||
s2 = mx.new_stream(mx.gpu)
|
||||
old_limit = mx.metal.set_memory_limit(1000)
|
||||
old_limit = mx.set_memory_limit(1000)
|
||||
|
||||
x = mx.ones((512, 512), stream=s2)
|
||||
for _ in range(80):
|
||||
@@ -195,7 +195,7 @@ class TestEval(mlx_tests.MLXTestCase):
|
||||
y = mx.abs(x, stream=s2)
|
||||
z = mx.abs(y, stream=s2)
|
||||
mx.eval(z)
|
||||
mx.metal.set_memory_limit(old_limit)
|
||||
mx.set_memory_limit(old_limit)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user