Reduce JVP (#2854)

This commit is contained in:
Awni Hannun
2025-12-02 16:17:47 -08:00
committed by GitHub
parent eff0e31f00
commit d8ceae7b77
3 changed files with 73 additions and 6 deletions

View File

@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
grad_fn(model)
self.assertEqual(model[1].item(), 2.0)
def test_reduce_jvp(self):
a = mx.arange(4)
b = mx.array([3, 2, 1, 0])
out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
self.assertEqual(jout[0].item(), 6)
out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
self.assertEqual(jout[0].item(), 18)
out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
self.assertEqual(jout[0].item(), 3)
out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
self.assertEqual(jout[0].item(), 0)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()