diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index fa7b384f5..9603e3cf1 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -1716,7 +1716,7 @@ std::vector<array> Unflatten::vjp(
     const std::vector<array>& cotangents,
     const std::vector<int>&,
     const std::vector<array>&) {
-  return {flatten(cotangents[0], axis_, axis_ + shape_.size(), stream())};
+  return {flatten(cotangents[0], axis_, axis_ + shape_.size() - 1, stream())};
 }
 
 std::vector<array> Unflatten::jvp(
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 5f3b62a8a..a1106b2a4 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -605,6 +605,21 @@ class TestAutograd(mlx_tests.MLXTestCase):
         dfdx = mx.grad(fun)(x)
         self.assertTrue(mx.allclose(dfdx, -2j * mx.ones_like(x)))
 
+    def test_flatten_unflatten_vjps(self):
+        def fun(x):
+            y = mx.unflatten(x, 0, (2, 2))
+            return y.sum()
+
+        x = mx.zeros((4, 8))
+        self.assertEqual(mx.grad(fun)(x).shape, (4, 8))
+
+        def fun(x):
+            y = mx.flatten(x, 0, 2)
+            return y.sum()
+
+        x = mx.zeros((2, 4, 8))
+        self.assertEqual(mx.grad(fun)(x).shape, (2, 4, 8))
+
 
 if __name__ == "__main__":
     unittest.main()
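
Note on the one-line C++ change, illustrated with the public Python API (a minimal sketch, not part of the patch): mx.flatten treats its end axis as inclusive, so the VJP of Unflatten should merge only the shape_.size() axes that unflatten introduced, i.e. axes axis_ through axis_ + shape_.size() - 1. With the old end axis, one extra trailing axis was folded in and the cotangent's shape no longer matched the primal's.

    import mlx.core as mx

    # unflatten splits axis 0 of a (4, 8) array into two axes of size 2,
    # giving shape (2, 2, 8)
    x = mx.zeros((4, 8))
    y = mx.unflatten(x, 0, (2, 2))
    print(y.shape)  # (2, 2, 8)

    # the cotangent of unflatten has y's shape; flattening it back must use
    # an inclusive end axis of 0 + 2 - 1 = 1 to recover the primal's shape
    cotan = mx.ones_like(y)
    print(mx.flatten(cotan, 0, 1).shape)  # (4, 8) -> matches x (fixed behavior)
    print(mx.flatten(cotan, 0, 2).shape)  # (32,)  -> the old off-by-one result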