From 5e89aace9b8d69eff2ca4553a100b08bf703ec39 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 19 Nov 2024 10:44:04 -0800 Subject: [PATCH] Fix concatenate vmap (#1600) --- mlx/primitives.cpp | 34 +++++++++++++++++++++++----------- python/tests/test_vmap.py | 20 ++++++++++++++++++++ 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index ea60774d1..743b695dd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -836,31 +836,43 @@ std::vector Concatenate::jvp( std::pair, std::vector> Concatenate::vmap( const std::vector& inputs, const std::vector& axes) { - std::vector t_inputs; int out_ax = -1; + int first_vmap = -1; + // Find the first vmapped input - int i = 0; - for (; i < axes.size(); i++) { - t_inputs.push_back(inputs[i]); + for (int i = 0; i < axes.size(); i++) { if (axes[i] >= 0) { out_ax = axes[i]; + first_vmap = i; break; } } - if (out_ax >= 0) { - // Advance to the next input - i++; - // Move vmap axes to the same spot. - for (; i < axes.size(); ++i) { - if (out_ax != axes[i] && axes[i] >= 0) { + // No vmap, should we even be in here? + if (out_ax < 0) { + return {{concatenate(inputs, axis_, stream())}, {out_ax}}; + } + + // Make sure vmapped arrays have all vmapped axes in the same location and + // expand non-vmapped arrays to be compatible with the vmapped ones. + std::vector t_inputs; + int N = inputs[first_vmap].shape(out_ax); + int axis = axis_ + (axis_ >= out_ax); + auto cat_shape = inputs[first_vmap].shape(); + for (int i = 0; i < axes.size(); i++) { + if (axes[i] >= 0) { + if (out_ax != axes[i]) { t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream())); } else { t_inputs.push_back(inputs[i]); } + } else { + cat_shape[axis] = inputs[i].shape(axis_); + t_inputs.push_back(broadcast_to( + expand_dims(inputs[i], out_ax, stream()), cat_shape, stream())); } } - auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax); + return {{concatenate(t_inputs, axis, stream())}, {out_ax}}; } diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 512865073..f71c500e2 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -527,6 +527,26 @@ class TestVmap(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): out = mx.vmap(const_func, in_axes=(0, 0))(a, b) + def test_vmap_concatenate(self): + x = mx.random.uniform(shape=(2, 2, 2)) + + def cat_fun(x, y): + return mx.concatenate([x, y], axis=1) + + def cat_constant(x): + y = mx.ones((2, 1)) + return mx.concatenate([x, y], 1) + + out = mx.vmap(cat_fun, in_axes=(0, 2))(x, x) + target = mx.stack( + [mx.concatenate([x[i], x[:, :, i]], axis=1) for i in range(2)] + ) + self.assertTrue(mx.array_equal(out, target)) + + out = mx.vmap(cat_constant)(x) + target = mx.concatenate([x, mx.ones((2, 2, 1))], axis=2) + self.assertTrue(mx.array_equal(out, target)) + if __name__ == "__main__": unittest.main()