From 9eb7d7362f72e83fa1e5870db1498323a7be2210 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 8 Feb 2025 09:22:13 -0800 Subject: [PATCH] Fix Split::vmap (#1845) --- mlx/primitives.cpp | 4 +++- python/tests/test_vmap.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 90ae57906..16acfdc9c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4273,7 +4273,9 @@ std::pair, std::vector> Split::vmap( const std::vector& inputs, const std::vector& axes) { int axis_left = axes[0] >= 0 && axes[0] <= axis_; - return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes}; + auto output = split(inputs[0], indices_, axis_ + axis_left, stream()); + std::vector output_axes(output.size(), axes[0]); + return {std::move(output), std::move(output_axes)}; } std::vector Split::vjp( diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index b98bdb0fc..ceadf36a4 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -596,6 +596,18 @@ class TestVmap(mlx_tests.MLXTestCase): out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd) self.assertEqual(out.shape, (4, 5, 1)) + def test_vmap_split_vmap(self): + def fun(x): + a, b = mx.split(x, 2, 1) + return mx.concatenate([b, a], 1) + + x = mx.ones((5, 6, 7)) + y = mx.ones((5, 4, 6, 7)) + fx = fun(x) + fy = mx.vmap(fun, in_axes=1)(y) + self.assertEqual(fx.shape, (5, 6, 7)) + self.assertEqual(fy.shape, (4, 5, 6, 7)) + if __name__ == "__main__": unittest.main()