From d8ceae7b7786840b1f54ee36f066260026847413 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 2 Dec 2025 16:17:47 -0800
Subject: [PATCH] Reduce JVP (#2854)

---
 mlx/primitives.cpp            | 56 +++++++++++++++++++++++++++++++++++
 mlx/primitives.h              |  7 +----
 python/tests/test_autograd.py | 16 ++++++++++
 3 files changed, 73 insertions(+), 6 deletions(-)

diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 5d96301bd..1b8fbc9b6 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
   }
 }
 
+std::vector<array> Reduce::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  auto in = primals[0];
+  auto s = stream();
+
+  auto grad_op = [&s, reduce_type = reduce_type_](
+                     const array& x, const array& tan, int axis) {
+    if (reduce_type == Reduce::Min) {
+      auto idx = argmin(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else if (reduce_type == Reduce::Max) {
+      auto idx = argmax(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else {
+      auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
+      auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
+      auto out = multiply(multiply(p1, p2, s), tan, s);
+      return sum(out, axis, true, s);
+    }
+  };
+
+  auto tan = tangents[0];
+  if (reduce_type_ == Reduce::Sum) {
+    return {sum(tan, axes_, true, s)};
+  } else {
+    if (axes_.size() > 1) {
+      std::vector<int> transpose_to;
+      {
+        // Find the transpose needed to move axes_ to the back.
+        int j = 0;
+        for (int i = 0; i < in.ndim(); i++) {
+          if (j < axes_.size() && axes_[j] == i) {
+            j++;
+          } else {
+            transpose_to.push_back(i);
+          }
+        }
+        for (auto ax : axes_) {
+          transpose_to.push_back(ax);
+        }
+      }
+
+      int start_ax = in.ndim() - axes_.size();
+      in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);
+      tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);
+
+      auto grad = squeeze(grad_op(in, tan, -1), -1, s);
+      return {expand_dims(grad, axes_, s)};
+    } else {
+      return {grad_op(in, tan, axes_[0])};
+    }
+  }
+}
+
 std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 3d9f60b42..debd068ed 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1751,12 +1751,7 @@ class Reduce : public UnaryPrimitive {
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
   DEFINE_VMAP()
-
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
+  DEFINE_GRADS();
 
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 38bb6089d..9eadbd2ef 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad_fn(model)
         self.assertEqual(model[1].item(), 2.0)
 
+    def test_reduce_jvp(self):
+        a = mx.arange(4)
+        b = mx.array([3, 2, 1, 0])
+
+        out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 6)
+
+        out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 18)
+
+        out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 3)
+
+        out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 0)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
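
Note on the expected values in test_reduce_jvp (not part of the patch): with primal x = [0, 1, 2, 3] and tangent v = [3, 2, 1, 0], each JVP can be checked by hand. A minimal plain-Python sketch of that arithmetic, using only the values from the test:

# Hand check of the expected JVPs in test_reduce_jvp.
x = [0, 1, 2, 3]  # primal, mx.arange(4)
v = [3, 2, 1, 0]  # tangent

# sum: every partial derivative is 1, so the JVP is sum(v).
print(sum(v))  # 6

# prod: d(prod)/dx_i is the product of the other entries, here [6, 0, 0, 0]
# (this is what the forward and reverse exclusive cumprods compute in
# Reduce::jvp); the JVP is the dot product with the tangent.
partials = [1] * len(x)
for i in range(len(x)):
    for j in range(len(x)):
        if j != i:
            partials[i] *= x[j]
print(sum(p * t for p, t in zip(partials, v)))  # 6*3 + 0 + 0 + 0 = 18

# min / max: the tangent is routed through the argmin / argmax index,
# matching the take_along_axis(tan, idx) path in Reduce::jvp.
print(v[x.index(min(x))])  # min at index 0 -> 3
print(v[x.index(max(x))])  # max at index 3 -> 0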