From 90d04072b7ceee55fa63ab76381d4c0c20189af8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 17 Dec 2023 11:58:45 -0800 Subject: [PATCH] fix build w/ flatten (#195) --- mlx/ops.cpp | 31 +++++++++++++++++++++++-------- python/tests/test_array.py | 2 +- tests/ops_tests.cpp | 6 ++++++ 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 19b165f3a..147c2c111 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -283,21 +283,36 @@ array flatten( int end_axis /* = -1 */, StreamOrDevice s /* = {} */) { auto ndim = static_cast(a.ndim()); - start_axis += (start_axis < 0 ? ndim : 0); - end_axis += (end_axis < 0 ? ndim + 1 : 0); - start_axis = std::max(0, start_axis); - end_axis = std::min(ndim, end_axis); - if (end_axis < start_axis) { + auto start_ax = start_axis + (start_axis < 0 ? ndim : 0); + auto end_ax = end_axis + (end_axis < 0 ? ndim : 0); + start_ax = std::max(0, start_ax); + end_ax = std::min(ndim - 1, end_ax); + if (a.ndim() == 0) { + return reshape(a, {1}, s); + } + if (end_ax < start_ax) { throw std::invalid_argument( "[flatten] start_axis must be less than or equal to end_axis"); } - if (start_axis == end_axis and a.ndim() != 0) { + if (start_ax >= ndim) { + std::ostringstream msg; + msg << "[flatten] Invalid start_axis " << start_axis << " for array with " + << ndim << " dimensions."; + throw std::invalid_argument(msg.str()); + } + if (end_ax < 0) { + std::ostringstream msg; + msg << "[flatten] Invalid end_axis " << end_axis << " for array with " + << ndim << " dimensions."; + throw std::invalid_argument(msg.str()); + } + if (start_ax == end_ax) { return a; } - std::vector new_shape(a.shape().begin(), a.shape().begin() + start_axis); + std::vector new_shape(a.shape().begin(), a.shape().begin() + start_ax); new_shape.push_back(-1); new_shape.insert( - new_shape.end(), a.shape().begin() + end_axis + 1, a.shape().end()); + new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end()); return reshape(a, new_shape, s); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index cabc6a114..85f1aa257 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -906,7 +906,7 @@ class TestArray(mlx_tests.MLXTestCase): # Check slice assign with negative indices works a = mx.zeros((5, 5), mx.int32) a[2:-2, 2:-2] = 4 - self.assertEquals(a[2, 2].item(), 4) + self.assertEqual(a[2, 2].item(), 4) def test_slice_negative_step(self): a_np = np.arange(20) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index d33c7340a..af53ce8a1 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -73,6 +73,12 @@ TEST_CASE("test flatten") { // Check start > end throws CHECK_THROWS(flatten(x, 2, 1)); + // Check start >= ndim throws + CHECK_THROWS(flatten(x, 5, 6)); + + // Check end < 0 throws + CHECK_THROWS(flatten(x, -5, -4)); + // Check scalar flattens to 1D x = array(1); CHECK_EQ(flatten(x, -3, -1).shape(), std::vector({1}));