From 7cca1727af6deb61654846c0332abe783bdb2aff Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 4 Sep 2024 19:10:43 -0700 Subject: [PATCH] Fix slice data size (#1394) * fix slice data size and add tests * fix contiguous flag * simplify stride and perform copy for non-contiguous arrays * fix cpu * comment --- mlx/array.h | 27 +++++++++++--- mlx/backend/common/primitives.cpp | 10 +++++- mlx/backend/common/slicing.cpp | 22 +++--------- mlx/backend/common/slicing.h | 5 +-- mlx/backend/common/utils.h | 7 ++-- mlx/backend/metal/normalization.cpp | 8 ++--- mlx/backend/metal/slicing.cpp | 14 ++++++-- mlx/backend/metal/slicing.h | 4 +-- mlx/backend/metal/softmax.cpp | 4 +-- mlx/ops.cpp | 4 +++ python/tests/test_fast.py | 8 +++++ tests/ops_tests.cpp | 55 ++++++++++++++++++++++++++++- 12 files changed, 129 insertions(+), 39 deletions(-) diff --git a/mlx/array.h b/mlx/array.h index 3f000e9b2..cce64e79c 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -219,11 +219,23 @@ class array { }; struct Flags { - // True if there are no gaps in the underlying data. Each item + // True iff there are no gaps in the underlying data. Each item // in the underlying data buffer belongs to at least one index. + // + // True iff: + // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size() bool contiguous : 1; + // True iff: + // strides[-1] == 1 and + // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in + // range(ndim - 1)) bool row_contiguous : 1; + + // True iff: + // strides[0] == 1 and + // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in + // range(1, ndim)) bool col_contiguous : 1; }; @@ -291,7 +303,16 @@ class array { return array_desc_->flags; } - /** The size (in elements) of the underlying buffer the array points to. */ + /** The size (in elements) of the underlying buffer the array points to. + * + * This can be different than the actual size of the array if the array has + * been broadcast or irregularly strided. If ``first`` is the offset into + * the data buffer of the first element of the array (i.e. the offset + * corresponding to ``arr[0, 0, ...]``) and last is the offset into the + * data buffer of the last element of the array (i.e. the offset + * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``. + * Note, ``data_size`` is in units of ``item_size`` (not bytes). + **/ size_t data_size() const { return array_desc_->data_size; } @@ -412,8 +433,6 @@ class array { void* data_ptr{nullptr}; // The size in elements of the data buffer the array accesses - // This can be different than the actual size of the array if it - // has been broadcast or irregularly strided. size_t data_size; // Contains useful meta data about the array diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 99adfaed4..14aa52bad 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -505,8 +505,16 @@ void Slice::eval(const std::vector& inputs, array& out) { /* int64_t o_offset = */ 0, /* CopyType ctype = */ CopyType::General); } else { + size_t data_end = 1; + for (int i = 0; i < end_indices_.size(); ++i) { + if (in.shape()[i] > 1) { + auto end_idx = start_indices_[i] + out.shape()[i] * strides_[i] - 1; + data_end += end_idx * in.strides()[i]; + } + } + size_t data_size = data_end - data_offset; std::vector ostrides{inp_strides.begin(), inp_strides.end()}; - shared_buffer_slice(in, ostrides, data_offset, out); + shared_buffer_slice(in, ostrides, data_offset, data_size, out); } } diff --git a/mlx/backend/common/slicing.cpp b/mlx/backend/common/slicing.cpp index 3425179e0..015cfad5f 100644 --- a/mlx/backend/common/slicing.cpp +++ b/mlx/backend/common/slicing.cpp @@ -6,18 +6,16 @@ namespace mlx::core { std::tuple> prepare_slice( const array& in, - std::vector& start_indices, - std::vector& strides) { + const std::vector& start_indices, + const std::vector& strides) { int64_t data_offset = 0; bool copy_needed = false; std::vector inp_strides(in.ndim(), 0); for (int i = 0; i < in.ndim(); ++i) { data_offset += start_indices[i] * in.strides()[i]; inp_strides[i] = in.strides()[i] * strides[i]; - copy_needed |= strides[i] < 0; } - return std::make_tuple(copy_needed, data_offset, inp_strides); } @@ -25,26 +23,16 @@ void shared_buffer_slice( const array& in, const std::vector& out_strides, size_t data_offset, + size_t data_size, array& out) { // Compute row/col contiguity - auto [data_size, is_row_contiguous, is_col_contiguous] = + auto [no_bsx_size, is_row_contiguous, is_col_contiguous] = check_contiguity(out.shape(), out_strides); auto flags = in.flags(); flags.row_contiguous = is_row_contiguous; flags.col_contiguous = is_col_contiguous; - - if (data_size == 1) { - // Broadcasted scalar array is contiguous. - flags.contiguous = true; - } else if (data_size == in.data_size()) { - // Means we sliced a broadcasted dimension so leave the "no holes" flag - // alone. - } else { - // We sliced something. So either we are row or col contiguous or we - // punched a hole. - flags.contiguous &= flags.row_contiguous || flags.col_contiguous; - } + flags.contiguous = (no_bsx_size == data_size); out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset); } diff --git a/mlx/backend/common/slicing.h b/mlx/backend/common/slicing.h index 842793783..9ee8216f4 100644 --- a/mlx/backend/common/slicing.h +++ b/mlx/backend/common/slicing.h @@ -8,13 +8,14 @@ namespace mlx::core { std::tuple> prepare_slice( const array& in, - std::vector& start_indices, - std::vector& strides); + const std::vector& start_indices, + const std::vector& strides); void shared_buffer_slice( const array& in, const std::vector& out_strides, size_t data_offset, + size_t data_size, array& out); } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index f8d2c9117..0eedfceec 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -135,7 +135,7 @@ template inline auto check_contiguity( const std::vector& shape, const std::vector& strides) { - size_t data_size = 1; + size_t no_broadcast_data_size = 1; size_t f_stride = 1; size_t b_stride = 1; bool is_row_contiguous = true; @@ -147,11 +147,12 @@ inline auto check_contiguity( f_stride *= shape[i]; b_stride *= shape[ri]; if (strides[i] > 0) { - data_size *= shape[i]; + no_broadcast_data_size *= shape[i]; } } - return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous); + return std::make_tuple( + no_broadcast_data_size, is_row_contiguous, is_col_contiguous); } } // namespace mlx::core diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index 40f433fa6..8338c0dbe 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -20,8 +20,8 @@ void RMSNorm::eval_gpu( // Make sure that the last dimension is contiguous std::vector copies; auto check_input = [&copies, &s](const array& x) -> const array& { - bool no_copy = x.strides()[x.ndim() - 1] == 1; - if (x.ndim() > 1) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } @@ -208,8 +208,8 @@ void LayerNorm::eval_gpu( // Make sure that the last dimension is contiguous std::vector copies; auto check_input = [&copies, &s](const array& x) -> const array& { - bool no_copy = x.strides()[x.ndim() - 1] == 1; - if (x.ndim() > 1) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index 1057f27b7..05f051438 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -11,8 +11,8 @@ namespace mlx::core { void slice_gpu( const array& in, array& out, - std::vector start_indices, - std::vector strides, + const std::vector& start_indices, + const std::vector& strides, const Stream& s) { // Calculate out strides, initial offset and if copy needs to be made auto [copy_needed, data_offset, inp_strides] = @@ -34,7 +34,15 @@ void slice_gpu( /* const Stream& s = */ s); } else { std::vector ostrides{inp_strides.begin(), inp_strides.end()}; - shared_buffer_slice(in, ostrides, data_offset, out); + size_t data_end = 1; + for (int i = 0; i < strides.size(); ++i) { + if (in.shape()[i] > 1) { + auto end_idx = start_indices[i] + out.shape()[i] * strides[i] - 1; + data_end += end_idx * in.strides()[i]; + } + } + size_t data_size = data_end - data_offset; + shared_buffer_slice(in, ostrides, data_offset, data_size, out); } } diff --git a/mlx/backend/metal/slicing.h b/mlx/backend/metal/slicing.h index cf81b7000..51da8b54c 100644 --- a/mlx/backend/metal/slicing.h +++ b/mlx/backend/metal/slicing.h @@ -9,8 +9,8 @@ namespace mlx::core { void slice_gpu( const array& in, array& out, - std::vector start_indices, - std::vector strides, + const std::vector& start_indices, + const std::vector& strides, const Stream& s); void concatenate_gpu( diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index cc25ac7f3..706343ff7 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -24,8 +24,8 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { // Make sure that the last dimension is contiguous std::vector copies; auto check_input = [&copies, &s](const array& x) -> const array& { - bool no_copy = x.strides()[x.ndim() - 1] == 1; - if (x.ndim() > 1) { + bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1; + if (no_copy && x.ndim() > 1) { auto s = x.strides()[x.ndim() - 2]; no_copy &= (s == 0 || s == x.shape().back()); } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ee9aed212..56f8fa234 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -615,6 +615,10 @@ inline auto normalize_slice( out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i]; } + // Simplify the stride if it's unused + if (out_shape[i] == 1) { + strides[i] = 1; + } } return std::make_pair(has_neg_strides, out_shape); diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 93dc2f261..c881eced6 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -429,6 +429,14 @@ class TestFast(mlx_tests.MLXTestCase): rx_fast = mx.fast.layer_norm(x, None, None, eps) self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype]) + def test_slice_into_layer_norm(self): + dim = 128 + eps = 1e-5 + x = mx.random.uniform(shape=(8, 100, 128))[:, 99:] + rx_fast = mx.fast.layer_norm(x, weight=None, bias=None, eps=eps) + rx = layer_norm(x, None, None, eps) + self.assertLess(mx.abs(rx - rx_fast).max(), 1e-4) + def test_layer_norm_grad(self): D = 32 eps = 1e-5 diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 33d77c626..0333c04e4 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -228,7 +228,7 @@ TEST_CASE("test slice") { CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item()); // Check contiguity preservation - x = ones({10, 10}) * 2; + x = ones({10, 10}); eval(x); CHECK(x.flags().row_contiguous); CHECK(!x.flags().col_contiguous); @@ -252,6 +252,59 @@ TEST_CASE("test slice") { eval(out); CHECK(!out.flags().row_contiguous); CHECK(!out.flags().col_contiguous); + + x = ones({6, 4, 10}); + out = slice(x, {0, 0, 0}, {6, 4, 10}, {2, 1, 2}); + eval(out); + CHECK(!out.flags().contiguous); + CHECK(!out.flags().row_contiguous); + CHECK(!out.flags().col_contiguous); + + // Check data size correctness + x = ones({4}); + out = slice(x, {0}, {2}); + eval(out); + CHECK_EQ(out.data_size(), 2); + + out = slice(x, {2}, {4}); + eval(out); + CHECK_EQ(out.data_size(), 2); + + out = slice(x, {0}, {4}, {2}); + eval(out); + CHECK_EQ(out.data_size(), 4); + + x = ones({4, 4}); + out = slice(x, {0, 0}, {2, 4}); + eval(out); + CHECK_EQ(out.data_size(), 8); + + out = slice(x, {0, 0}, {1, 2}); + eval(out); + CHECK_EQ(out.data_size(), 2); + + out = slice(x, {0, 1}, {4, 4}); + eval(out); + CHECK_EQ(out.data_size(), 15); + + out = slice(x, {1, 2}, {3, 4}); + eval(out); + CHECK_EQ(out.data_size(), 6); + + x = ones({4, 4, 4}); + out = slice(x, {0, 0, 0}, {4, 2, 2}); + eval(out); + CHECK_EQ(out.data_size(), 54); + + x = ones({4, 4, 4}); + out = slice(x, {2, 2, 2}, {3, 3, 3}); + eval(out); + CHECK_EQ(out.data_size(), 1); + + x = ones({4, 4, 4}); + out = slice(x, {2, 2, 2}, {3, 4, 3}); + eval(out); + CHECK_EQ(out.data_size(), 5); } TEST_CASE("test slice update") {