mirror of https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Reduce JVP (#2854)
@@ -3846,6 +3846,62 @@ std::vector<array> Reduce::vjp(
 }
 }
 
+std::vector<array> Reduce::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  auto in = primals[0];
+  auto s = stream();
+
+  auto grad_op = [&s, reduce_type = reduce_type_](
+                     const array& x, const array& tan, int axis) {
+    if (reduce_type == Reduce::Min) {
+      auto idx = argmin(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else if (reduce_type == Reduce::Max) {
+      auto idx = argmax(x, axis, true, s);
+      return take_along_axis(tan, idx, axis, s);
+    } else {
+      auto p1 = cumprod(x, axis, /*reverse=*/false, /*inclusive=*/false, s);
+      auto p2 = cumprod(x, axis, /*reverse=*/true, /*inclusive=*/false, s);
+      auto out = multiply(multiply(p1, p2, s), tan, s);
+      return sum(out, axis, true, s);
+    }
+  };
+
+  auto tan = tangents[0];
+  if (reduce_type_ == Reduce::Sum) {
+    return {sum(tan, axes_, true, s)};
+  } else {
+    if (axes_.size() > 1) {
+      std::vector<int> transpose_to;
+      {
+        // Find the transpose needed to move axes_ to the back.
+        int j = 0;
+        for (int i = 0; i < in.ndim(); i++) {
+          if (j < axes_.size() && axes_[j] == i) {
+            j++;
+          } else {
+            transpose_to.push_back(i);
+          }
+        }
+        for (auto ax : axes_) {
+          transpose_to.push_back(ax);
+        }
+      }
+
+      int start_ax = in.ndim() - axes_.size();
+      in = flatten(transpose(in, transpose_to, s), start_ax, -1, s);
+      tan = flatten(transpose(tan, transpose_to, s), start_ax, -1, s);
+
+      auto grad = squeeze(grad_op(in, tan, -1), -1, s);
+      return {expand_dims(grad, axes_, s)};
+    } else {
+      return {grad_op(in, tan, axes_[0])};
+    }
+  }
+}
+
 std::pair<std::vector<array>, std::vector<int>> Reduce::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
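A note on the grad_op above (not part of the diff): for Prod, the two exclusive cumprods multiply to the product of every element except x[i], so weighting the tangent by that factor and summing along the axis is the ordinary product rule; for Min and Max, the tangent is simply gathered at the arg-extremum index via take_along_axis. When more than one axis is reduced, the reduced axes are first transposed to the back and flattened into a single axis so grad_op only ever handles one axis. A minimal Python sketch of the Prod case, assuming mlx.core is installed (the helper name prod_jvp_1d is illustrative, not part of the change):

    import mlx.core as mx

    def prod_jvp_1d(x, tan, axis=-1):
        # p1[i] * p2[i] == product of x[j] over all j != i.
        p1 = mx.cumprod(x, axis=axis, reverse=False, inclusive=False)
        p2 = mx.cumprod(x, axis=axis, reverse=True, inclusive=False)
        return mx.sum(p1 * p2 * tan, axis=axis, keepdims=True)

    x = mx.arange(4)              # [0, 1, 2, 3]
    tan = mx.array([3, 2, 1, 0])
    print(prod_jvp_1d(x, tan))    # [18]: only i = 0 escapes the zero (1*2*3 * 3)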
@@ -1751,12 +1751,7 @@ class Reduce : public UnaryPrimitive {
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
 
   DEFINE_VMAP()
-
-  std::vector<array> vjp(
-      const std::vector<array>& primals,
-      const std::vector<array>& cotangents,
-      const std::vector<int>& argnums,
-      const std::vector<array>& outputs) override;
+  DEFINE_GRADS();
 
   std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
 
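In the header, the hand-written vjp declaration gives way to DEFINE_GRADS(), which presumably declares both the vjp and jvp overrides now that Reduce implements both. The practical effect is that reductions compose with either transform; a small usage sketch, assuming a build that includes this change (the printed values are hand-computed):

    import mlx.core as mx

    a = mx.arange(1, 5).astype(mx.float32)            # [1, 2, 3, 4], prod = 24
    _, vjp_out = mx.vjp(mx.prod, (a,), (mx.array(1.0),))
    _, jvp_out = mx.jvp(mx.prod, (a,), (mx.ones(4),))
    print(vjp_out[0])   # [24, 12, 8, 6]  (reverse mode: d prod / d a)
    print(jvp_out[0])   # 50              (forward mode: sum of the partials)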
@@ -798,6 +798,22 @@ class TestAutograd(mlx_tests.MLXTestCase):
         grad_fn(model)
         self.assertEqual(model[1].item(), 2.0)
 
+    def test_reduce_jvp(self):
+        a = mx.arange(4)
+        b = mx.array([3, 2, 1, 0])
+
+        out, jout = mx.jvp(mx.sum, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 6)
+
+        out, jout = mx.jvp(mx.prod, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 18)
+
+        out, jout = mx.jvp(mx.min, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 3)
+
+        out, jout = mx.jvp(mx.max, primals=(a,), tangents=(b,))
+        self.assertEqual(jout[0].item(), 0)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
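For a quick hand check of the expected values in the test (plain-Python arithmetic, not part of the test file): with a = [0, 1, 2, 3] and b = [3, 2, 1, 0], the sum JVP is just sum(b), the prod JVP follows the product rule and only the i = 0 term avoids the zero, and min/max route the tangent through index 0 and index 3 respectively:

    import math

    a = [0, 1, 2, 3]
    b = [3, 2, 1, 0]

    assert sum(b) == 6                                                        # sum
    assert sum(b[i] * math.prod(a[:i] + a[i + 1:]) for i in range(4)) == 18   # prod
    assert b[0] == 3 and b[3] == 0                                            # min at 0, max at 3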