From 5611e1a95ee4ae16c5ac982c217662fbb42d0d72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Abdussamet=20T=C3=BCrker?= <53705368+abdussamettrkr@users.noreply.github.com> Date: Tue, 26 Mar 2024 23:59:44 +0300 Subject: [PATCH] Fix unsqueeze with None (#899) * Fix unsqueeze with None * Clean unnecessary files --- python/src/indexing.cpp | 7 +++++-- python/tests/test_array.py | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 59e1ff194..b60b23cfa 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -367,6 +367,7 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { } bool squeeze_needed = false; + bool unsqueeze_needed = false; // Slice handling { @@ -395,17 +396,19 @@ array mlx_get_item_nd(array src, const nb::tuple& entries) { } axis++; + } else { + unsqueeze_needed = true; } } src = slice(src, starts, ends, strides); } // Unsqueeze handling - if (remaining_indices.size() > src.ndim() || squeeze_needed) { + if (unsqueeze_needed || squeeze_needed) { std::vector out_shape; int axis = 0; for (auto& idx : remaining_indices) { - if (idx.is_none()) { + if (unsqueeze_needed && idx.is_none()) { out_shape.push_back(1); } else if (squeeze_needed && nb::isinstance(idx)) { axis++; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 03a341f9c..6094e025f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -763,6 +763,10 @@ class TestArray(mlx_tests.MLXTestCase): a_sliced_npy = np.asarray(a_sliced_mlx) self.assertTrue(np.array_equal(a_sliced_npy, a_npy[None])) + a_sliced_mlx = a_mlx[:, None] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, None])) + # Multi dim indexing, all ints self.assertEqual(a_mlx[0, 0].item(), 0) self.assertEqual(a_mlx[0, 0].ndim, 0)