mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Reduce JVP (#2854)
This commit is contained in:
@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
|
||||
grad_fn(model)
|
||||
self.assertEqual(model[1].item(), 2.0)
|
||||
|
||||
def test_reduce_jvp(self):
|
||||
a = mx.arange(4)
|
||||
b = mx.array([3, 2, 1, 0])
|
||||
|
||||
out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
|
||||
self.assertEqual(jout[0].item(), 6)
|
||||
|
||||
out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
|
||||
self.assertEqual(jout[0].item(), 18)
|
||||
|
||||
out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
|
||||
self.assertEqual(jout[0].item(), 3)
|
||||
|
||||
out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
|
||||
self.assertEqual(jout[0].item(), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user