mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 00:54:37 +08:00
Custom transforms (#1246)
This commit is contained in:

committed by
GitHub

parent
a3c287354f
commit
5c1fa64fb0
@@ -496,6 +496,90 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_custom_function(self):
|
||||
# Make a custom function
|
||||
my_exp = mx.custom_function(mx.exp)
|
||||
|
||||
# Ensure everything works
|
||||
dy = mx.grad(my_exp)(mx.array(1.0))
|
||||
self.assertTrue(mx.allclose(dy, mx.exp(mx.array(1.0))))
|
||||
(ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
|
||||
self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
|
||||
self.assertTrue(mx.allclose(ex, dex))
|
||||
ex = mx.vmap(my_exp)(mx.ones(10))
|
||||
self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))
|
||||
|
||||
# Ensure that the vjp is being overriden but everything else still
|
||||
# works.
|
||||
@my_exp.vjp
|
||||
def my_exp_vjp(x, dx, ex):
|
||||
return mx.ones_like(x) * 42
|
||||
|
||||
dy = mx.grad(my_exp)(mx.array(1.0))
|
||||
self.assertTrue(mx.allclose(dy, mx.array(42.0)))
|
||||
(ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
|
||||
self.assertTrue(mx.allclose(dex, mx.exp(mx.array(1.0))))
|
||||
self.assertTrue(mx.allclose(ex, dex))
|
||||
ex = mx.vmap(my_exp)(mx.ones(10))
|
||||
self.assertTrue(mx.allclose(ex, mx.exp(mx.ones(10))))
|
||||
|
||||
# Ensure that setting the jvp and vmap also works.
|
||||
@my_exp.jvp
|
||||
def my_exp_jvp(x, dx):
|
||||
return mx.ones_like(x) * 7 * dx
|
||||
|
||||
@my_exp.vmap
|
||||
def my_exp_vmap(x, axis):
|
||||
return mx.ones_like(x) * 3, axis
|
||||
|
||||
dy = mx.grad(my_exp)(mx.array(1.0))
|
||||
self.assertTrue(mx.allclose(dy, mx.array(42.0)))
|
||||
(ex,), (dex,) = mx.jvp(my_exp, [mx.array(1.0)], [mx.array(1.0)])
|
||||
self.assertTrue(mx.allclose(dex, mx.array(7.0)))
|
||||
self.assertTrue(mx.allclose(ex, mx.exp(mx.array(1.0))))
|
||||
ex = mx.vmap(my_exp)(mx.ones(10))
|
||||
self.assertTrue(mx.allclose(ex, 3 * mx.ones(10)))
|
||||
|
||||
# Test pytrees
|
||||
@mx.custom_function
|
||||
def my_double(params):
|
||||
return {"out": 2 * params["x"] * params["y"]}
|
||||
|
||||
dy = mx.grad(lambda p: my_double(p)["out"].sum())(
|
||||
{"x": mx.ones(2), "y": mx.ones(2)}
|
||||
)
|
||||
self.assertTrue(mx.allclose(dy["x"], mx.ones(2) * 2))
|
||||
self.assertTrue(mx.allclose(dy["y"], mx.ones(2) * 2))
|
||||
|
||||
@my_double.vjp
|
||||
def random_grads(primals, cotangents, outputs):
|
||||
return {"x": mx.zeros_like(primals["x"]), "y": mx.ones_like(primals["y"])}
|
||||
|
||||
dy = mx.grad(lambda p: my_double(p)["out"].sum())(
|
||||
{"x": mx.ones(2), "y": mx.ones(2)}
|
||||
)
|
||||
self.assertTrue(mx.allclose(dy["x"], mx.zeros(2)))
|
||||
self.assertTrue(mx.allclose(dy["y"], mx.ones(2)))
|
||||
|
||||
def outer_f(a, b):
|
||||
return my_double({"x": a, "y": b})["out"]
|
||||
|
||||
inputs = [mx.random.normal(shape=(2,)) for i in range(2)]
|
||||
tans = [mx.random.normal(shape=(2,)) for i in range(2)]
|
||||
out1, dout1 = mx.jvp(outer_f, inputs, tans)
|
||||
|
||||
@my_double.jvp
|
||||
def random_grads(primals, tangents):
|
||||
return {
|
||||
"out": 2 * primals["x"] * tangents["y"]
|
||||
+ 2 * primals["y"] * tangents["x"]
|
||||
+ 1
|
||||
}
|
||||
|
||||
out2, dout2 = mx.jvp(outer_f, inputs, tans)
|
||||
self.assertTrue(mx.allclose(out1[0], out2[0]))
|
||||
self.assertTrue(mx.allclose(dout1[0] + 1, dout2[0]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user