Fix bug in tape order during simplify (#816)

* fix bug in tape order during simplify

* properly fix compile

* last bug
This commit is contained in:
Awni Hannun
2024-03-11 17:29:05 -07:00
committed by GitHub
parent 0ae22b915b
commit 0e95b64942
2 changed files with 55 additions and 4 deletions

View File

@@ -613,6 +613,46 @@ class TestCompile(mlx_tests.MLXTestCase):
out = fun()
mx.eval(out)
def test_compile_vjp(self):
def fun(w):
w1 = w + w
w2 = w + w
return w @ w1 + w2 @ w2
def step(w):
out, grad = mx.vjp(fun, (w,), (mx.array([[1.0, 1.0], [1.0, 1.0]]),))
return out[0], grad[0]
w = mx.zeros((2, 2))
mx.eval(w)
expected = step(w)
out = mx.compile(step)(w)
self.assertTrue(mx.allclose(expected[0], out[0]))
self.assertTrue(mx.allclose(expected[1], out[1]))
def fun(w1, w2, x):
x = x @ w1
y = x @ w2
x = x + y * y
return (x * x).sum()
w1 = mx.zeros((4, 4))
w2 = mx.zeros((4, 4))
x = mx.zeros((4, 4))
def step(w1, w2, x):
loss, gradient = mx.value_and_grad(fun)(w1, w2, x)
w1 = w1 + gradient
return loss, w1
mx.eval(x, w1, w2)
expected = step(w1, w2, x)
out = mx.compile(step)(w1, w2, x)
self.assertTrue(mx.allclose(expected[0], out[0]))
self.assertTrue(mx.allclose(expected[1], out[1]))
if __name__ == "__main__":
unittest.main()