Reduce JVP (#2854)

Awni Hannun
2025-12-02 16:17:47 -08:00
committed by GitHub
parent eff0e31f00
commit d8ceae7b77
3 changed files with 73 additions and 6 deletions


@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
   }
 }
+
+std::vector<array> Reduce::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  auto in = primals[0];
+  auto s = stream();
+  auto grad_op = [&s, reduce_type = reduce_type_](
+                     const array& x, const array& tan, int axis) {
+    if (reduce_type == Reduce::Min) {
+      auto idx = argmin(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else if (reduce_type == Reduce::Max) {
+      auto idx = argmax(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else {
+      auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
+      auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
+      auto out = multiply(multiply(p1, p2, s), tan, s);
+      return sum(out, axis, true, s);
+    }
+  };
+  auto tan = tangents[0];
+  if (reduce_type_ == Reduce::Sum) {
+    return {sum(tan, axes_, true, s)};
+  } else {
+    if (axes_.size() > 1) {
+      std::vector<int> transpose_to;
+      {
+        // Find the transpose needed to move axes_ to the back.
+        int j = 0;
+        for (int i = 0; i < in.ndim(); i++) {
+          if (j < axes_.size() && axes_[j] == i) {
+            j++;
+          } else {
+            transpose_to.push_back(i);
+          }
+        }
+        for (auto ax : axes_) {
+          transpose_to.push_back(ax);
+        }
+      }
+      int start_ax = in.ndim() - axes_.size();
+      in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);
+      tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);
+      auto grad = squeeze(grad_op(in, tan, -1), -1, s);
+      return {expand_dims(grad, axes_, s)};
+    } else {
+      return {grad_op(in, tan, axes_[0])};
+    }
+  }
+}
 
 std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
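
The forward-mode rules mirror the usual reduction derivatives: the JVP of a sum is the sum of the tangents; for min and max the tangent at the arg-extremum is selected with take_along_axis; and for prod each output gets sum_i tan_i * prod_{j != i} x_j, computed from two exclusive cumulative products so that zero entries (as in the test below, where the input contains a 0) are handled without dividing by x. Multi-axis reductions are lowered to the single-axis case by transposing the reduced axes to the back, flattening them into one axis, and re-inserting them with expand_dims. The Python sketch below is not part of the commit; it only mirrors grad_op's product branch against mx.jvp (the helper name prod_jvp_reference is illustrative, and it assumes mlx is installed):

import mlx.core as mx

def prod_jvp_reference(x, tan, axis):
    # prod_{j != i} x_j via two exclusive cumulative products:
    # product of everything before i times everything after i.
    p1 = mx.cumprod(x, axis=axis, reverse=False, inclusive=False)
    p2 = mx.cumprod(x, axis=axis, reverse=True, inclusive=False)
    return mx.sum(p1 * p2 * tan, axis=axis, keepdims=True)

x = mx.array([2.0, 3.0, 4.0])
t = mx.array([1.0, 1.0, 1.0])
# Products-of-others are [12, 8, 6]; with an all-ones tangent the JVP is 26.
ref = prod_jvp_reference(x, t, axis=0)
_, jout = mx.jvp(mx.prod, primals=(x,), tangents=(t,))
print(ref.item(), jout[0].item())  # 26.0 26.0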


@@ -1751,12 +1751,7 @@ class Reduce : public UnaryPrimitive {
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
   DEFINE_VMAP()
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
+  DEFINE_GRADS();
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
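
DEFINE_GRADS() declares both vjp and jvp for the primitive (with the signatures shown above), so the standalone vjp declaration can be dropped once Reduce implements jvp. The practical effect is that forward-mode differentiation now flows through reductions end to end. A hypothetical toy example, not taken from the commit (the names loss, w, and v are made up for illustration, assuming mlx is installed):

import mlx.core as mx

def loss(w):
    # A toy squared-error loss that ends in a sum reduction.
    x = mx.array([[1.0, 2.0], [3.0, 4.0]])
    y = mx.array([1.0, 0.0])
    r = x @ w - y
    return mx.sum(r * r)

w = mx.zeros((2,))
v = mx.array([1.0, 0.0])  # direction for the directional derivative
_, dloss = mx.jvp(loss, primals=(w,), tangents=(v,))
print(dloss[0].item())  # 2 * sum(r * (x @ v)) = -2.0 at w = 0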


@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad_fn(model)
         self.assertEqual(model[1].item(), 2.0)
+
+    def test_reduce_jvp(self):
+        a = mx.arange(4)
+        b = mx.array([3, 2, 1, 0])
+
+        out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 6)
+
+        out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 18)
+
+        out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 3)
+
+        out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 0)
 
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
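
For reference, the expected values in the test follow directly from the rules above, with a = [0, 1, 2, 3] and tangent b = [3, 2, 1, 0]: sum gives 3 + 2 + 1 + 0 = 6; prod gives 3*(1*2*3) + 2*0 + 1*0 + 0*0 = 18; min selects the tangent at argmin(a) = 0, i.e. 3; max selects the tangent at argmax(a) = 3, i.e. 0. A compact equivalent of the test, shown only as a sketch assuming mlx is installed:

import mlx.core as mx

a = mx.arange(4)              # [0, 1, 2, 3]
b = mx.array([3, 2, 1, 0])
for fn, expected in [(mx.sum, 6), (mx.prod, 18), (mx.min, 3), (mx.max, 0)]:
    _, jout = mx.jvp(fn, primals=(a,), tangents=(b,))
    assert jout[0].item() == expected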