mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Limit compile buffers (#1887)
* limit compile buffers * maybe not flaky test
This commit is contained in:
@@ -177,6 +177,7 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
def test_donation(self):
|
||||
x = mx.random.normal((1024,))
|
||||
mx.eval(x)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
|
@@ -815,6 +815,31 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun(*inputs)
|
||||
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))
|
||||
|
||||
@mx.compile
|
||||
def fun(arrs):
|
||||
for _ in range(6):
|
||||
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
|
||||
return arrs[0]
|
||||
|
||||
arrs = [mx.array([1.0, 2.0]) for _ in range(64)]
|
||||
out = fun(arrs)
|
||||
self.assertTrue(mx.allclose(out, mx.array([64.0, 128.0])))
|
||||
|
||||
def test_compile_many_outputs(self):
|
||||
|
||||
@mx.compile
|
||||
def fun(arr):
|
||||
arrs = [arr] * 64
|
||||
first_arrs = None
|
||||
for _ in range(6):
|
||||
arrs = [x + y for x, y in zip(arrs[::2], arrs[1::2])]
|
||||
if first_arrs is None:
|
||||
first_arrs = arrs
|
||||
return arrs[0], first_arrs
|
||||
|
||||
out = fun(mx.array([1.0, 2.0]))
|
||||
self.assertTrue(mx.allclose(out[0], mx.array([64.0, 128.0])))
|
||||
|
||||
def test_shapeless_compile_matmul(self):
|
||||
a = mx.array([0.0, 1.0, 2.0])
|
||||
b = mx.array([0.0, 1.0, 2.0])
|
||||
|
@@ -385,6 +385,7 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
mx.eval(x)
|
||||
save_file = os.path.join(self.test_dir, "donation.npy")
|
||||
mx.save(save_file, x)
|
||||
mx.synchronize(mx.default_stream(mx.default_device()))
|
||||
|
||||
mx.metal.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
|
Reference in New Issue
Block a user