mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
matmul jvps (#1772)
This commit is contained in:
@@ -634,6 +634,41 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(grads[0].dtype, mx.float32)
|
||||
self.assertEqual(grads[1].dtype, mx.float16)
|
||||
|
||||
def test_matmul_jvps(self):
|
||||
a = mx.random.uniform(shape=(4, 4))
|
||||
b = mx.random.uniform(shape=(4, 4))
|
||||
c = mx.random.uniform(shape=(4, 4))
|
||||
d = mx.random.uniform(shape=(4, 4))
|
||||
|
||||
_, tangent = mx.jvp(lambda a: a @ b, (a,), (c,))
|
||||
self.assertTrue(mx.allclose(tangent[0], c @ b))
|
||||
|
||||
_, tangent = mx.jvp(lambda b: a @ b, (b,), (d,))
|
||||
self.assertTrue(mx.allclose(tangent[0], a @ d))
|
||||
|
||||
_, tangent = mx.jvp(lambda a, b: a @ b, (a, b), (c, d))
|
||||
self.assertTrue(mx.allclose(tangent[0], a @ d + c @ b))
|
||||
|
||||
x = mx.random.uniform(shape=(4, 4))
|
||||
y = mx.random.uniform(shape=(4, 4))
|
||||
z = mx.random.uniform(shape=(4, 4))
|
||||
|
||||
_, (tangent,) = mx.jvp(lambda a, b, c: a @ b + c, (a, b, c), (x, y, z))
|
||||
_, (expected,) = mx.jvp(lambda a, b, c: mx.addmm(c, a, b), (a, b, c), (x, y, z))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
_, (tangent,) = mx.jvp(lambda a, c: a @ b + c, (a, c), (x, z))
|
||||
_, (expected,) = mx.jvp(lambda a, c: mx.addmm(c, a, b), (a, c), (x, z))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
_, (tangent,) = mx.jvp(lambda b, c: a @ b + c, (b, c), (y, z))
|
||||
_, (expected,) = mx.jvp(lambda b, c: mx.addmm(c, a, b), (b, c), (y, z))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
_, (tangent,) = mx.jvp(lambda c: a @ b + c, (c,), (z,))
|
||||
_, (expected,) = mx.jvp(lambda c: mx.addmm(c, a, b), (c,), (z,))
|
||||
self.assertTrue(mx.allclose(tangent, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user