mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Fix bug in tape order during simplify (#816)
* fix bug in tape order during simplify * properly fix compile * last bug
This commit is contained in:
@@ -613,6 +613,46 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
out = fun()
|
||||
mx.eval(out)
|
||||
|
||||
def test_compile_vjp(self):
|
||||
def fun(w):
|
||||
w1 = w + w
|
||||
w2 = w + w
|
||||
return w @ w1 + w2 @ w2
|
||||
|
||||
def step(w):
|
||||
out, grad = mx.vjp(fun, (w,), (mx.array([[1.0, 1.0], [1.0, 1.0]]),))
|
||||
return out[0], grad[0]
|
||||
|
||||
w = mx.zeros((2, 2))
|
||||
mx.eval(w)
|
||||
|
||||
expected = step(w)
|
||||
out = mx.compile(step)(w)
|
||||
self.assertTrue(mx.allclose(expected[0], out[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], out[1]))
|
||||
|
||||
def fun(w1, w2, x):
|
||||
x = x @ w1
|
||||
y = x @ w2
|
||||
x = x + y * y
|
||||
return (x * x).sum()
|
||||
|
||||
w1 = mx.zeros((4, 4))
|
||||
w2 = mx.zeros((4, 4))
|
||||
x = mx.zeros((4, 4))
|
||||
|
||||
def step(w1, w2, x):
|
||||
loss, gradient = mx.value_and_grad(fun)(w1, w2, x)
|
||||
w1 = w1 + gradient
|
||||
return loss, w1
|
||||
|
||||
mx.eval(x, w1, w2)
|
||||
expected = step(w1, w2, x)
|
||||
out = mx.compile(step)(w1, w2, x)
|
||||
|
||||
self.assertTrue(mx.allclose(expected[0], out[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], out[1]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user