diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 59e1ff194..b60b23cfa 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -367,6 +367,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { } bool squeeze_needed = false; + bool unsqueeze_needed = false; // Slice handling { @@ -395,17 +396,19 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { } axis++; + } else { + unsqueeze_needed = true; } } src = slice(src, starts, ends, strides); } // Unsqueeze handling - if (remaining_indices.size() > src.ndim() || squeeze_needed) { + if (unsqueeze_needed || squeeze_needed) { std::vector out_shape; int axis = 0; for (auto& idx : remaining_indices) { - if (idx.is_none()) { + if (unsqueeze_needed && idx.is_none()) { out_shape.push_back(1); } else if (squeeze_needed && nb::isinstance(idx)) { axis++; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 03a341f9c..6094e025f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -763,6 +763,10 @@ class TestArray(mlx_tests.MLXTestCase): a_sliced_npy = np.asarray(a_sliced_mlx) self.assertTrue(np.array_equal(a_sliced_npy, a_npy[None])) + a_sliced_mlx = a_mlx[:, None] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, None])) + # Multi dim indexing, all ints self.assertEqual(a_mlx[0, 0].item(), 0) self.assertEqual(a_mlx[0, 0].ndim, 0)