Custom transforms (#1246)

This commit is contained in:
Angelos Katharopoulos
2024-07-10 18:00:01 -07:00
committed by GitHub
parent a3c287354f
commit 5c1fa64fb0
16 changed files with 734 additions and 39 deletions

View File

@@ -496,6 +496,90 @@ class TestAutograd(mlx_tests.MLXTestCase):
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
self.assertTrue(mx.allclose(out, expected))
def test_custom_function(self):
    """Check ``mx.custom_function`` with default and overridden transforms.

    The test proceeds in stages, and the order matters because each
    decorator call changes the behavior of the wrapped function for all
    subsequent transforms:

    1. A freshly wrapped ``mx.exp`` must behave like ``mx.exp`` under
       ``mx.grad``, ``mx.jvp`` and ``mx.vmap``.
    2. After overriding only the vjp, ``mx.grad`` must use the override
       while ``mx.jvp`` and ``mx.vmap`` keep their default behavior.
    3. After also overriding the jvp and vmap, all three transforms must
       use the custom rules.
    4. The same vjp/jvp overrides must work when the function takes and
       returns pytrees (dicts of arrays).
    """
    # Make a custom function
    my_exp = mx.custom_function(mx.exp)

    # Ensure everything works
    dy = mx.grad(my_exp)(mx.array(1.0))
    self.assertTrue(mx.allclose(dy, mx.exp(mx.array(1.0))))
    (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
    self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
    self.assertTrue(mx.allclose(ex, dex))
    ex = mx.vmap(my_exp)(mx.ones(10))
    self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))

    # Ensure that the vjp is being overridden but everything else still
    # works.
    @my_exp.vjp
    def my_exp_vjp(x, dx, ex):
        # Deliberately wrong gradient so the override is detectable.
        return mx.ones_like(x) * 42

    dy = mx.grad(my_exp)(mx.array(1.0))
    self.assertTrue(mx.allclose(dy, mx.array(42.0)))
    (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
    self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
    self.assertTrue(mx.allclose(ex, dex))
    ex = mx.vmap(my_exp)(mx.ones(10))
    self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))

    # Ensure that setting the jvp and vmap also works.
    @my_exp.jvp
    def my_exp_jvp(x, dx):
        return mx.ones_like(x) * 7 * dx

    @my_exp.vmap
    def my_exp_vmap(x, axis):
        return mx.ones_like(x) * 3, axis

    dy = mx.grad(my_exp)(mx.array(1.0))
    self.assertTrue(mx.allclose(dy, mx.array(42.0)))
    (ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
    # The tangent comes from the custom jvp; the primal output is still exp.
    self.assertTrue(mx.allclose(dex, mx.array(7.0)))
    self.assertTrue(mx.allclose(ex, mx.exp(mx.array(1.0))))
    ex = mx.vmap(my_exp)(mx.ones(10))
    self.assertTrue(mx.allclose(ex, 3 * mx.ones(10)))

    # Test pytrees
    @mx.custom_function
    def my_double(params):
        return {"out": 2 * params["x"] * params["y"]}

    # Default differentiation through the dict input/output.
    dy = mx.grad(lambda p: my_double(p)["out"].sum())(
        {"x": mx.ones(2), "y": mx.ones(2)}
    )
    self.assertTrue(mx.allclose(dy["x"], mx.ones(2) * 2))
    self.assertTrue(mx.allclose(dy["y"], mx.ones(2) * 2))

    @my_double.vjp
    def my_double_vjp(primals, cotangents, outputs):
        # Arbitrary gradients with the same dict structure as the input.
        return {"x": mx.zeros_like(primals["x"]), "y": mx.ones_like(primals["y"])}

    dy = mx.grad(lambda p: my_double(p)["out"].sum())(
        {"x": mx.ones(2), "y": mx.ones(2)}
    )
    self.assertTrue(mx.allclose(dy["x"], mx.zeros(2)))
    self.assertTrue(mx.allclose(dy["y"], mx.ones(2)))

    def outer_f(a, b):
        return my_double({"x": a, "y": b})["out"]

    inputs = [mx.random.normal(shape=(2,)) for _ in range(2)]
    tans = [mx.random.normal(shape=(2,)) for _ in range(2)]

    # Reference jvp computed before the custom jvp is installed.
    out1, dout1 = mx.jvp(outer_f, inputs, tans)

    @my_double.jvp
    def my_double_jvp(primals, tangents):
        # Product-rule tangent plus a constant offset of 1, so the custom
        # rule can be told apart from the reference jvp above.
        return {
            "out": 2 * primals["x"] * tangents["y"]
            + 2 * primals["y"] * tangents["x"]
            + 1
        }

    out2, dout2 = mx.jvp(outer_f, inputs, tans)
    self.assertTrue(mx.allclose(out1[0], out2[0]))
    self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0]))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()