diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index c9f839d4b..ea60774d1 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. - #include #include #include @@ -1683,48 +1682,58 @@ std::pair, std::vector> Gather::vmap( auto gather_axes = axes_; auto slice_sizes = slice_sizes_; auto src_vmapped = axes[0] >= 0; - auto indices_vmapped = - std::any_of(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; }); - auto out_ax = - *std::find_if(axes.begin(), axes.end(), [](int a) { return a >= 0; }); + auto ind_vmap_ax_ptr = + std::find_if(axes.begin() + 1, axes.end(), [](int a) { return a >= 0; }); + int out_ax = -1; + bool indices_vmapped = (ind_vmap_ax_ptr != axes.end()); + if (indices_vmapped) { + out_ax = *ind_vmap_ax_ptr; + } else if (src_vmapped) { + out_ax = axes[0]; + } // Reorder all the index arrays so the vmap axis is in the same spot. - for (int i = 1; i < axes.size(); ++i) { - if (out_ax != axes[i] && axes[i] >= 0) { - indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream()); + if (indices_vmapped) { + for (int i = 1; i < axes.size(); ++i) { + if (out_ax != axes[i] && axes[i] >= 0) { + indices[i - 1] = moveaxis(indices[i - 1], axes[i], out_ax, stream()); + } else if (axes[i] < 0) { + indices[i - 1] = expand_dims(indices[i - 1], out_ax, stream()); + } } } + int idx_dims = indices.empty() ? 0 : indices[0].ndim(); + if (src_vmapped) { - int max_dims = 0; - for (auto& idx : indices) { - max_dims = std::max(static_cast(idx.ndim()), max_dims); - } - auto new_ax_loc = - std::find_if(gather_axes.begin(), gather_axes.end(), [&out_ax](int a) { - return a >= out_ax; - }); - for (; new_ax_loc < gather_axes.end(); new_ax_loc++) { - (*new_ax_loc)++; + for (auto& ax : gather_axes) { + if (ax >= axes[0]) { + ax++; + } } if (indices_vmapped) { // Make a new index array for the vmapped dimension + auto vmap_inds = arange(0, src.shape(axes[0]), stream()); // Reshape it so it broadcasts with other index arrays + { + auto shape = std::vector(idx_dims, 1); + shape[out_ax] = vmap_inds.size(); + vmap_inds = reshape(vmap_inds, std::move(shape), stream()); + } // Update gather axes and slice sizes accordingly - auto shape = std::vector(max_dims - out_ax, 1); - auto vmap_inds = arange(0, src.shape(out_ax), stream()); - shape[0] = vmap_inds.shape(0); - vmap_inds = reshape(vmap_inds, shape, stream()); - slice_sizes.insert(slice_sizes.begin() + out_ax, 1); - auto new_ax_idx = new_ax_loc - gather_axes.begin(); - gather_axes.insert(new_ax_loc, out_ax); - indices.insert(indices.begin() + new_ax_idx, vmap_inds); + slice_sizes.insert(slice_sizes.begin() + axes[0], 1); + gather_axes.push_back(axes[0]); + indices.push_back(vmap_inds); } else { - slice_sizes.insert(slice_sizes.begin() + axes[0], src.shape(axes[0])); - out_ax = max_dims + axes[0]; + slice_sizes.insert(slice_sizes.begin() + out_ax, src.shape(out_ax)); + out_ax += idx_dims; } } - return {{gather(src, indices, gather_axes, slice_sizes, stream())}, {out_ax}}; + auto out = gather(src, indices, gather_axes, slice_sizes, stream()); + if (src_vmapped && indices_vmapped) { + out = squeeze(out, idx_dims + axes[0], stream()); + } + return {{out}, {out_ax}}; } std::vector Gather::vjp( diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 866012a12..512865073 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -370,6 +370,51 @@ class TestVmap(mlx_tests.MLXTestCase): mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5) ) + def test_vmap_gather(self): + def gather(a, idx): + return a[idx] + + a = mx.array([[1, 2], [3, 4]]) + idx = mx.array(0) + out = mx.vmap(gather, (0, None))(a, idx) + self.assertTrue(mx.array_equal(out, mx.array([1, 3]))) + + out = mx.vmap(gather, (1, None))(a, idx) + self.assertTrue(mx.array_equal(out, mx.array([1, 2]))) + + idx = mx.array([0, 1]) + out = mx.vmap(gather, (0, 0))(a, idx) + self.assertTrue(mx.array_equal(out, mx.array([1, 4]))) + + a = mx.ones((2, 3, 4)) + idx = mx.zeros(4, mx.int32) + out = mx.vmap(gather, (2, 0))(a, idx) + self.assertEqual(out.shape, (4, 3)) + + f = mx.vmap(gather, (0, None)) + f = mx.vmap(gather, (0, 0)) + out = f(mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32)) + self.assertEqual(out.shape, (2, 4)) + + def gather(a, idxa, idxb): + return a[idxa, idxb] + + a = mx.ones((2, 3, 4)) + idxa = mx.zeros((2, 3), mx.int32) + idxb = mx.zeros(3, mx.int32) + out = mx.vmap(gather, (0, 0, None))(a, idxa, idxb) + self.assertEqual(out.shape, (2, 3)) + + idxa = mx.zeros((3, 1, 2), mx.int32) + idxb = mx.zeros((2, 3, 1, 2), mx.int32) + out = mx.vmap(gather, (0, None, 0))(a, idxa, idxb) + self.assertEqual(out.shape, (2, 3, 1, 2)) + + idxa = mx.zeros((3, 1, 2), mx.int32) + idxb = mx.zeros((3, 1, 2, 2), mx.int32) + out = mx.vmap(gather, (0, None, 3))(a, idxa, idxb) + self.assertEqual(out.shape, (2, 3, 1, 2)) + def test_vmap_scatter(self): def scatter(a): a[mx.array(0)] = mx.array(0.0)