matmul jvps (#1772)

This commit is contained in:
Awni Hannun
2025-01-17 10:36:26 -08:00
committed by GitHub
parent f288db8d34
commit 0c259961ac
4 changed files with 138 additions and 21 deletions

View File

@@ -634,6 +634,41 @@ class TestAutograd(mlx_tests.MLXTestCase):
self.assertEqual(grads[0].dtype, mx.float32)
self.assertEqual(grads[1].dtype, mx.float16)
def test_matmul_jvps(self):
a = mx.random.uniform(shape=(4, 4))
b = mx.random.uniform(shape=(4, 4))
c = mx.random.uniform(shape=(4, 4))
d = mx.random.uniform(shape=(4, 4))
_, tangent = mx.jvp(lambda a: a @ b, (a,), (c,))
self.assertTrue(mx.allclose(tangent[0], c @ b))
_, tangent = mx.jvp(lambda b: a @ b, (b,), (d,))
self.assertTrue(mx.allclose(tangent[0], a @ d))
_, tangent = mx.jvp(lambda a, b: a @ b, (a, b), (c, d))
self.assertTrue(mx.allclose(tangent[0], a @ d + c @ b))
x = mx.random.uniform(shape=(4, 4))
y = mx.random.uniform(shape=(4, 4))
z = mx.random.uniform(shape=(4, 4))
_, (tangent,) = mx.jvp(lambda a, b, c: a @ b + c, (a, b, c), (x, y, z))
_, (expected,) = mx.jvp(lambda a, b, c: mx.addmm(c, a, b), (a, b, c), (x, y, z))
self.assertTrue(mx.allclose(tangent, expected))
_, (tangent,) = mx.jvp(lambda a, c: a @ b + c, (a, c), (x, z))
_, (expected,) = mx.jvp(lambda a, c: mx.addmm(c, a, b), (a, c), (x, z))
self.assertTrue(mx.allclose(tangent, expected))
_, (tangent,) = mx.jvp(lambda b, c: a @ b + c, (b, c), (y, z))
_, (expected,) = mx.jvp(lambda b, c: mx.addmm(c, a, b), (b, c), (y, z))
self.assertTrue(mx.allclose(tangent, expected))
_, (tangent,) = mx.jvp(lambda c: a @ b + c, (c,), (z,))
_, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
self.assertTrue(mx.allclose(tangent, expected))
if __name__ == "__main__":
unittest.main()