diff --git a/mlx/array.cpp b/mlx/array.cpp index a8c77d150..c43d6a104 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -167,7 +167,7 @@ void array::copy_shared_buffer( const Strides& strides, Flags flags, size_t data_size, - size_t offset /* = 0 */) { + int64_t offset /* = 0 */) { array_desc_->data = other.array_desc_->data; array_desc_->strides = strides; array_desc_->flags = flags; diff --git a/mlx/array.h b/mlx/array.h index c8a529d7d..25b1f5766 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -439,7 +439,7 @@ class array { const Strides& strides, Flags flags, size_t data_size, - size_t offset = 0); + int64_t offset = 0); void copy_shared_buffer(const array& other); diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp index 38f3c1ba0..cad2fd619 100644 --- a/mlx/backend/common/slicing.cpp +++ b/mlx/backend/common/slicing.cpp @@ -14,17 +14,13 @@ std::tuple prepare_slice( data_offset += start_indices[i] * in.strides()[i]; inp_strides[i] = in.strides()[i] * strides[i]; } - // Normalize the offset - if (data_offset < 0) { - data_offset += in.data_size(); - } return std::make_tuple(data_offset, inp_strides); } void shared_buffer_slice( const array& in, const Strides& out_strides, - size_t data_offset, + int64_t data_offset, size_t data_size, array& out) { // Compute row/col contiguity @@ -51,17 +47,24 @@ void slice( // Calculate out strides, initial offset auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides); - int64_t data_end = 1; - for (int i = 0; i < start_indices.size(); ++i) { - if (in.shape()[i] > 1) { - auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1; - data_end += end_idx * in.strides()[i]; + + // Get the location of the end based on the inp strides and out.shape() + int64_t low_idx = 0; + int64_t high_idx = 0; + for (int i = 0; i < inp_strides.size(); ++i) { + auto delta = inp_strides[i] * (out.shape()[i] - 1); + if (inp_strides[i] > 0) { + high_idx += delta; + } else { + low_idx += delta; } } - if (data_end < 0) { - data_end += in.data_size(); + int64_t data_size = (high_idx - low_idx) + 1; + if (data_size < 0) { + std::ostringstream msg; + msg << "[slice] Computed invalid data size: " << data_size << "."; + throw std::runtime_error(msg.str()); } - size_t data_size = (data_end - data_offset); shared_buffer_slice(in, inp_strides, data_offset, data_size, out); } diff --git a/mlx/backend/gpu/slicing.cpp b/mlx/backend/gpu/slicing.cpp index fde2a01cd..7f0fec27b 100644 --- a/mlx/backend/gpu/slicing.cpp +++ b/mlx/backend/gpu/slicing.cpp @@ -11,7 +11,7 @@ void slice_gpu( array& out, const Shape& start_indices, const Shape& strides, - const Stream& s) { + const Stream&) { slice(in, out, start_indices, strides); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 8a353d743..38921e564 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -3058,6 +3058,11 @@ class TestOps(mlx_tests.MLXTestCase): out = a[::-1] self.assertTrue(mx.array_equal(out[-1, :], a[0, :])) + a = mx.arange(8) + for _ in range(4): + a = a[::-1] + self.assertTrue(mx.array_equal(a, mx.arange(8))) + def test_complex_ops(self): x = mx.array( [ diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 38e64559b..a0e0b1547 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -292,7 +292,7 @@ TEST_CASE("test slice") { out = slice(x, {0}, {4}, {2}); eval(out); - CHECK_EQ(out.data_size(), 4); + CHECK_EQ(out.data_size(), 3); x = ones({4, 4}); out = slice(x, {0, 0}, {2, 4}); @@ -325,6 +325,20 @@ TEST_CASE("test slice") { out = slice(x, {2, 2, 2}, {3, 4, 3}); eval(out); CHECK_EQ(out.data_size(), 5); + + x = ones({8}); + out = slice(x, {7}, {-9}, {-1}); + eval(out); + CHECK_EQ(out.data_size(), 8); + + out = slice(x, {7}, {-9}, {-1}); + eval(out); + CHECK_EQ(out.data_size(), 8); + + x = ones({4, 2}); + out = slice(x, {3, 0}, {-5, 2}, {-1, 1}); + eval(out); + CHECK_EQ(out.data_size(), 8); } TEST_CASE("test slice update") {