mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix compile copying (#2871)
This commit is contained in:
@@ -1223,6 +1223,35 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.array_equal(out3[0], mx.array([11, 12, 13])))
|
||||
self.assertTrue(mx.array_equal(out3[1], mx.array([40, 50, 60])))
|
||||
|
||||
def test_compile_output_with_siblings(self):
|
||||
@mx.compile
|
||||
def fun(x, y):
|
||||
return mx.divmod(mx.abs(x), mx.abs(y))[0]
|
||||
|
||||
out = fun(mx.array(1.0), mx.array(1.0))
|
||||
self.assertEqual(out.item(), 1.0)
|
||||
|
||||
# Make sure the following compiles without issue
|
||||
def loss_fn(params, x):
|
||||
emb, w = params
|
||||
return mx.fast.layer_norm(emb[x], w, None, 1e-4).sum()
|
||||
|
||||
emb = mx.zeros((10, 32))
|
||||
w = mx.zeros((32,))
|
||||
|
||||
loss_and_grad_fn = mx.value_and_grad(loss_fn)
|
||||
|
||||
x = mx.zeros(shape=(4, 32), dtype=mx.int32)
|
||||
mx.eval(x, emb, w)
|
||||
|
||||
@mx.compile
|
||||
def step(emb, w, x):
|
||||
loss, grads = loss_and_grad_fn((emb, w), x)
|
||||
return loss, grads
|
||||
|
||||
loss, grads = step(emb, w, x)
|
||||
mx.eval(loss, grads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user