# Copyright © 2023 Apple Inc.
|
|
|
import unittest
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx_tests
|
|
|
|
|
|
|
|
|
|
|
|
class TestAutograd(mlx_tests.MLXTestCase):
    """Tests for MLX function transforms: jvp, vjp, grad and value_and_grad."""

    def test_jvp(self):
        """Forward-mode products for unary, binary and multi-output functions."""
        fun = lambda x: 2 * x
        out, dout = mx.jvp(fun, [mx.array(1.0)], [mx.array(2.0)])
        self.assertEqual(out[0].item(), 2.0)
        self.assertEqual(dout[0].item(), 4.0)

        # d(x * y) = y * dx + x * dy with x=4, y=2, dx=3, dy=2.
        fun = lambda x, y: x * y
        _, out = mx.jvp(
            fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0), mx.array(2.0)]
        )
        self.assertEqual(out[0].item(), 4.0 * 2.0 + 2.0 * 3.0)

        # A function returning a tuple yields one tangent per output.
        fun = lambda x, y, z: (x * y, y * z)
        _, out = mx.jvp(
            fun,
            [mx.array(2.0), mx.array(4.0), mx.array(6.0)],
            [mx.array(1.0), mx.array(3.0), mx.array(1.0)],
        )
        self.assertEqual(len(out), 2)
        self.assertEqual(out[0].item(), 4.0 * 1.0 + 2.0 * 3.0)
        self.assertEqual(out[1].item(), 4.0 * 1.0 + 6.0 * 3.0)

    def test_vjp(self):
        """Reverse-mode products for unary, binary and multi-output functions."""
        fun = lambda x: 2 * x
        out, dout = mx.vjp(fun, [mx.array(1.0)], [mx.array(2.0)])
        self.assertEqual(out[0].item(), 2.0)
        self.assertEqual(dout[0].item(), 4.0)

        # One cotangent, one gradient per primal input.
        fun = lambda x, y: x * y
        _, dout = mx.vjp(fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0)])
        self.assertEqual(dout[0].item(), 6.0)
        self.assertEqual(dout[1].item(), 12.0)

        # Two outputs take two cotangents; y receives contributions from both.
        fun = lambda x, y, z: (x * y, y * z)
        _, out = mx.vjp(
            fun,
            [mx.array(2.0), mx.array(4.0), mx.array(6.0)],
            [mx.array(1.0), mx.array(3.0)],
        )
        self.assertEqual(len(out), 3)
        self.assertEqual(out[0].item(), 4.0 * 1.0)
        self.assertEqual(out[1].item(), 2.0 * 1.0 + 6.0 * 3.0)
        self.assertEqual(out[2].item(), 4.0 * 3.0)

    def test_grad(self):
        """grad / value_and_grad: values, higher orders, argnums and errors."""
        fun = lambda x: x * x

        value, dfdx = mx.value_and_grad(fun)(mx.array(0.5))
        self.assertEqual(value.item(), 0.25)
        self.assertEqual(dfdx.item(), 1.0)

        dfdx = mx.grad(fun)(mx.array(0.5))
        self.assertEqual(dfdx.item(), 1.0)

        # Nested grad gives higher-order derivatives of x^2.
        df2dx2 = mx.grad(mx.grad(fun))(mx.array(0.5))
        self.assertEqual(df2dx2.item(), 2.0)
        df3dx3 = mx.grad(mx.grad(mx.grad(fun)))(mx.array(0.5))
        self.assertEqual(df3dx3.item(), 0.0)

        fun = lambda x, y: x * y
        x = mx.array(2.0)
        y = mx.array(3.0)
        dfdx = mx.grad(fun, argnums=0)(x, y)
        self.assertEqual(dfdx.item(), 3.0)
        dfdx = mx.grad(fun, argnums=1)(x, y)
        self.assertEqual(dfdx.item(), 2.0)

        # Pass non array args to functions works
        fun = lambda x, y: x
        value, dfdx = mx.value_and_grad(fun)(mx.array(2.0), "hello")
        self.assertEqual(value.item(), 2.0)
        self.assertEqual(dfdx.item(), 1.0)

        dfdx = mx.grad(fun)(mx.array(2.0), "hello")
        self.assertEqual(dfdx.item(), 1.0)

        # Raises when function does not return array
        fun = lambda x: "hello"
        with self.assertRaises(ValueError):
            mx.grad(fun)(mx.array(2.0))

        # Raises for invalid argument number or argument type
        fun = lambda x: x
        with self.assertRaises(ValueError):
            mx.grad(fun, argnums=2)(mx.array(2.0))
        with self.assertRaises(ValueError):
            mx.grad(fun, argnums=-2)(mx.array(2.0))
        with self.assertRaises(ValueError):
            mx.grad(fun)("hello")

        # Raises when output is not a scalar array
        fun = lambda x: mx.sum(x, keepdims=True)
        with self.assertRaises(ValueError):
            mx.grad(fun)(mx.ones((2, 2)))

    def test_grad_trees(self):
        """Gradients follow the tuple/dict/list structure of the arguments."""
        # A tuple of argnums yields a tuple of gradients.
        fun = lambda x, y: x * y
        value, dfdx = mx.value_and_grad(fun, (0, 1))(mx.array(0.5), mx.array(2.0))
        self.assertEqual(value.item(), 1.0)
        self.assertIsInstance(dfdx, tuple)
        self.assertEqual(dfdx[0].item(), 2.0)
        self.assertEqual(dfdx[1].item(), 0.5)

        # A single int argnum yields a bare gradient.
        value, dfdx = mx.value_and_grad(fun, 1)(mx.array(0.5), mx.array(2.0))
        self.assertEqual(value.item(), 1.0)
        self.assertEqual(dfdx.item(), 0.5)

        # A dict argument yields a dict of gradients with the same keys.
        fun = lambda p: p["x"] * p["y"]
        value, dfdx = mx.value_and_grad(fun)({"x": mx.array(0.5), "y": mx.array(2.0)})
        self.assertEqual(value.item(), 1.0)
        self.assertEqual(dfdx["x"].item(), 2.0)
        self.assertEqual(dfdx["y"].item(), 0.5)

        # A python float leaf and an out-of-range argnum are both rejected.
        with self.assertRaises(ValueError):
            mx.value_and_grad(fun)({"x": 0.5, "y": mx.array(2.0)})
        with self.assertRaises(ValueError):
            mx.value_and_grad(fun, (0, 1))({"x": mx.array(0.5), "y": mx.array(2.0)})

        # Deeply nested containers: the gradient mirrors the input tree.
        fun = lambda p, b: mx.square(p[0]["foo"][2]) * b
        value, dfdx = mx.value_and_grad(fun)(
            [{"foo": [[], [], mx.array(2.0)]}], mx.array(0.5)
        )
        self.assertEqual(value.item(), 2.0)
        self.assertEqual(dfdx[0]["foo"][2].item(), 2.0)

        # Malformed argnums are rejected when the transform is built.
        fun = lambda x: x
        with self.assertRaises(TypeError):
            mx.value_and_grad(fun, (None, None))
        with self.assertRaises(ValueError):
            mx.value_and_grad(fun, tuple())

    def test_auxiliary_values(self):
        """value_and_grad passes auxiliary outputs through; grad rejects them."""

        def fun(x, y):
            l = (x * y).sum()
            extra = {"loss": l, "foo": y.square() + x.square(), "bar": [1, 2, 3, y, x]}
            return l, extra

        fun_value_grad = mx.value_and_grad(fun)
        fun_grad = mx.grad(fun)

        # Only the first returned value is differentiated; the rest is carried along.
        (_, aux), grad = fun_value_grad(mx.ones((2, 2)), mx.ones((2, 2)))
        self.assertEqual(aux["loss"].item(), 4)
        self.assertTrue(mx.array_equal(grad, mx.ones((2, 2))))
        self.assertTrue(mx.array_equal(aux["foo"], 2 * mx.ones((2, 2))))
        self.assertEqual(aux["bar"][:3], [1, 2, 3])
        self.assertTrue(mx.array_equal(aux["bar"][3], mx.ones((2, 2))))
        self.assertTrue(mx.array_equal(aux["bar"][4], mx.ones((2, 2))))

        with self.assertRaises(ValueError):
            _ = fun_grad(mx.ones((2, 2)), mx.ones((2, 2)))

    def test_grad_kwargs(self):
        """Differentiating with respect to keyword arguments via argnames."""
        fun = lambda x, y: x * y
        a, b = mx.array(0.5), mx.array(2.0)
        dfdx = mx.grad(fun)
        self.assertEqual(dfdx(a, b).item(), 2.0)
        self.assertEqual(dfdx(a, y=b).item(), 2.0)
        # Default argnums=0 needs a positional argument to differentiate.
        with self.assertRaises(ValueError):
            dfdx(x=a, y=b)

        # argnames only: the result is (None, {name: grad}).
        dfdy = mx.grad(fun, argnums=[], argnames=["y"])
        with self.assertRaises(ValueError):
            dfdy(a, b)
        grads = dfdy(a, y=b)
        self.assertIsInstance(grads, tuple)
        self.assertIsNone(grads[0])
        self.assertIsInstance(grads[1], dict)
        self.assertEqual(grads[1]["y"].item(), 0.5)
        grads = dfdy(x=a, y=b)
        self.assertEqual(grads[1]["y"].item(), 0.5)
        self.assertEqual(len(grads[1]), 1)

        # Mixed argnums + argnames: argument 0 must be passed positionally and
        # "y" by keyword.
        dfdxy = mx.grad(fun, argnums=[0], argnames=["y"])
        with self.assertRaises(ValueError):
            dfdxy(a, b)
        with self.assertRaises(ValueError):
            dfdxy(x=a, y=b)
        grads = dfdxy(a, y=b)
        self.assertIsInstance(grads, tuple)
        self.assertEqual(grads[0].item(), 2.0)
        self.assertIsInstance(grads[1], dict)
        self.assertEqual(grads[1]["y"].item(), 0.5)

        # Several argnums give a nested tuple in the first slot.
        fun = lambda x, y, z: x * y * z
        dfdxyz = mx.grad(fun, argnums=[0, 1], argnames=["z"])
        c = mx.array(4.0)
        grads = dfdxyz(a, b, z=c)
        self.assertIsInstance(grads, tuple)
        self.assertIsInstance(grads[0], tuple)
        self.assertEqual(grads[0][0].item(), 8.0)
        self.assertEqual(grads[0][1].item(), 2.0)
        self.assertIsInstance(grads[1], dict)
        self.assertEqual(grads[1]["z"].item(), 1.0)

        # argnames without argnums: behaves like argnums=[].
        fun = lambda x, y: x * y
        dfdy = mx.grad(fun, argnames=["y"])
        grads = dfdy(a, y=b)
        self.assertIsInstance(grads, tuple)
        self.assertIsNone(grads[0])
        self.assertIsInstance(grads[1], dict)
        self.assertEqual(grads[1]["y"].item(), 0.5)

    def test_captured(self):
        """Arrays captured by closure are constants, not differentiated inputs."""
        a = mx.array(5.0)
        f = lambda x: a + x
        g = lambda x: a + a
        h = lambda x: x + x

        dfdx = mx.grad(f)
        self.assertEqual(dfdx(a).item(), 1.0)

        # g only reads the captured `a`, so its gradient is zero.
        dgdx = mx.grad(g)
        self.assertEqual(dgdx(a).item(), 0.0)

        dhdx = mx.grad(h)
        self.assertEqual(dhdx(a).item(), 2.0)

        d2fdx2 = mx.grad(dfdx)
        self.assertEqual(d2fdx2(a).item(), 0.0)

        d2gdx2 = mx.grad(dgdx)
        self.assertEqual(d2gdx2(a).item(), 0.0)

        d2hdx2 = mx.grad(dhdx)
        self.assertEqual(d2hdx2(a).item(), 0.0)

    def test_stop_gradient(self):
        """stop_gradient zeroes the cotangent flowing back through its input."""
        shape_in = (4, 4)
        w_in = mx.ones(shape_in)
        x_in = mx.ones(shape_in)
        cotan = mx.ones(shape_in)

        def h(w, x):
            x1 = 2 * x
            y = mx.stop_gradient(x1)
            y1 = 3 * y
            return w @ y1

        _, vjps = mx.vjp(h, [w_in, x_in], [cotan])
        mx.eval(vjps)

        # Each entry of d(w @ 6*ones)/dw under a ones cotangent is 4 * 6 = 24.
        self.assertTrue(mx.allclose(vjps[0], 24.0 * mx.ones(shape_in)))
        self.assertTrue(mx.allclose(vjps[1], mx.zeros(shape_in)))

        g = lambda x: h(w_in, x)
        _, vjps = mx.vjp(g, [x_in], [cotan])
        mx.eval(vjps)

        self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in)))

    def test_update_state(self):
        """Side-effect state written inside a grad'd function can be evaluated."""
        y = mx.array([1.0])
        state = mx.zeros((2,))

        def fn(y, x):
            nonlocal state
            x = y * x
            state = state + x
            return x.sum()

        x = mx.ones((2,))
        mx.grad(fn)(y, x)
        mx.eval(state)
        self.assertTrue(mx.allclose(state, mx.ones((2,))))

    def test_scatter_vjp(self):
        """Gradients through in-place index assignment."""

        # Overwritten positions receive no gradient.
        def fun(x, idx):
            x[idx] = 2.0
            return x.sum()

        dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
        self.assertTrue(mx.array_equal(dfdx, mx.array([1.0, 0.0, 1.0])))
        self.assertEqual(dfdx.dtype, mx.float32)

        # The scattered-in values receive the gradient of their destination.
        y = mx.array([0.0, 1.0, 2.0])

        def fun(x, idx):
            y[idx] = x
            return y.sum()

        dfdx = mx.grad(fun)(mx.array([2.0]), mx.array([1]))
        self.assertTrue(mx.array_equal(dfdx, mx.array([1.0])))
        self.assertEqual(dfdx.dtype, mx.float32)

    def test_scatter_max_vjp(self):
        """vjp of `.at[...].maximum` routes the cotangent to the winning operand."""

        def fun(src, updates):
            x = src.at[1].maximum(updates)
            return x

        cotan = mx.array([4.0, 5.0, 6.0])
        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan])
        mx.eval(vjps)

        # Update larger than value
        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 0.0, 6.0])))
        self.assertTrue(mx.allclose(vjps[1], mx.array([5.0])))

        cotan = mx.array([[4.0], [5.0], [6.0]])
        _, vjps = mx.vjp(
            fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan]
        )
        mx.eval(vjps)

        # Update and value are equal
        self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
        self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))

    def test_scatter_min_vjp(self):
        """vjp of `.at[...].minimum` routes the cotangent to the winning operand."""

        def fun(src, updates):
            x = src.at[1].minimum(updates)
            return x

        cotan = mx.array([4.0, 5.0, 6.0])
        _, vjps = mx.vjp(fun, [mx.array([1.0, 2.0, 3.0]), mx.array([[3.0]])], [cotan])
        mx.eval(vjps)

        # Update larger than value
        self.assertTrue(mx.allclose(vjps[0], mx.array([4.0, 5.0, 6.0])))
        self.assertTrue(mx.allclose(vjps[1], mx.array([0.0])))

        cotan = mx.array([[4.0], [5.0], [6.0]])
        _, vjps = mx.vjp(
            fun, [mx.array([[1.0], [2.0], [3.0]]), mx.array([[[2.0]]])], [cotan]
        )
        mx.eval(vjps)

        # Update and value are equal
        self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
        self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))

    def test_split_against_slice(self):
        """Gradients through `split` match gradients through equivalent slices."""

        def f_split(x):
            a, _, b = x.split(3, -1)
            return (a * b).sum()

        def f_slice(x):
            step = x.shape[-1] // 3
            a = x[..., :step]
            b = x[..., -step:]
            return (a * b).sum()

        x = mx.random.uniform(shape=(100, 300))
        mx.eval(x)

        df1 = mx.grad(f_split)
        df2 = mx.grad(f_slice)

        self.assertTrue(mx.allclose(df1(x), df2(x)))

    def test_vjp_types(self):
        """Gradients keep the floating-point dtype of the input."""
        dtypes = (mx.float16, mx.bfloat16, mx.float32)

        def fun(x):
            return x

        for t in dtypes:
            # subTest reports which dtype failed instead of stopping at the first.
            with self.subTest(fun="identity", dtype=t):
                out = mx.grad(fun)(mx.array(1.0, t))
                self.assertEqual(out.dtype, t)

        def fun(x):
            return x.sum()

        for t in dtypes:
            with self.subTest(fun="sum", dtype=t):
                out = mx.grad(fun)(mx.array(1.0, t))
                self.assertEqual(out.dtype, t)

        def fun(x, y):
            return (x + y).sum()

        for t in dtypes:
            with self.subTest(fun="add_sum", dtype=t):
                out = mx.grad(fun)(mx.array(1.0, t), mx.array(1.0, t))
                self.assertEqual(out.dtype, t)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): hand the test cases in this module to
    # unittest's command-line runner. "__main__" is unittest.main's default module.
    unittest.main(module="__main__")
|