fix compile copying (#2871)

This commit is contained in:
Awni Hannun
2025-12-04 12:32:56 -08:00
committed by GitHub
parent 941cfe23d7
commit a6d6717181
2 changed files with 39 additions and 1 deletions

View File

@@ -1223,6 +1223,35 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
def test_compile_output_with_siblings(self):
@mx.compile
def fun(x, y):
return mx.divmod(mx.abs(x), mx.abs(y))[0]
out = fun(mx.array(1.0), mx.array(1.0))
self.assertEqual(out.item(), 1.0)
# Make sure the following compiles without issue
def loss_fn(params, x):
emb, w = params
return mx.fast.layer_norm(emb[x], w, None, 1e-4).sum()
emb = mx.zeros((10, 32))
w = mx.zeros((32,))
loss_and_grad_fn = mx.value_and_grad(loss_fn)
x = mx.zeros(shape=(4, 32), dtype=mx.int32)
mx.eval(x, emb, w)
@mx.compile
def step(emb, w, x):
loss, grads = loss_and_grad_fn((emb, w), x)
return loss, grads
loss, grads = step(emb, w, x)
mx.eval(loss, grads)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()