From 9d2644122459835cab85a08209302d8beef80972 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 19 Aug 2024 16:05:06 -0700 Subject: [PATCH] Fix contiguity check (#1336) Co-authored-by: Alex Barron --- mlx/backend/common/utils.h | 4 +-- mlx/backend/metal/fft.cpp | 4 +-- tests/ops_tests.cpp | 56 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index f3082ef748..14252f3783 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -115,8 +115,8 @@ inline auto check_contiguity( bool is_col_contiguous = true; for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) { - is_row_contiguous &= strides[i] == f_stride || shape[i] == 1; - is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1; + is_col_contiguous &= strides[i] == f_stride || shape[i] == 1; + is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1; f_stride *= shape[i]; b_stride *= shape[ri]; if (strides[i] > 0) { diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 8fb2c93778..202c83544a 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -546,8 +546,8 @@ void fft_op( auto [data_size, is_row_contiguous, is_col_contiguous] = check_contiguity(x.shape(), strides); - flags.col_contiguous = is_row_contiguous; - flags.row_contiguous = is_col_contiguous; + flags.col_contiguous = is_col_contiguous; + flags.row_contiguous = is_row_contiguous; flags.contiguous = data_size == x_copy.size(); x_copy.set_data( diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index fb4fea1509..dc86037cbf 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -100,6 +100,36 @@ TEST_CASE("test reshape") { CHECK_EQ(y.strides()[4], 8); // y.strides()[5] can be anything since y.shape()[5] == 1 CHECK_EQ(x.data(), y.data()); + + // Check contiguity preservation + x = ones({10, 10}); + eval(x); + CHECK(x.flags().row_contiguous); + CHECK(!x.flags().col_contiguous); + y = reshape(x, {2, 5, 10}); + eval(y); + CHECK(y.flags().row_contiguous); + CHECK(!y.flags().col_contiguous); + y = reshape(x, {10, 1, 10, 1}); + eval(y); + CHECK(y.flags().row_contiguous); + CHECK(!y.flags().col_contiguous); + x = transpose(x, {1, 0}); + eval(x); + CHECK(!x.flags().row_contiguous); + CHECK(x.flags().col_contiguous); + y = reshape(x, {2, 5, 10}); + eval(y); + CHECK(!y.flags().row_contiguous); + CHECK(y.flags().col_contiguous); + y = reshape(x, {2, 50}); + eval(y); + CHECK(y.flags().row_contiguous); + CHECK(!y.flags().col_contiguous); + y = reshape(x, {10, 1, 10, 1}); + eval(y); + CHECK(!y.flags().row_contiguous); + CHECK(y.flags().col_contiguous); } TEST_CASE("test flatten") { @@ -196,6 +226,32 @@ TEST_CASE("test slice") { out = slice(x, {0, 0}, {2, 4}, {1, 2}); CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item()); + + // Check contiguity preservation + x = ones({10, 10}) * 2; + eval(x); + CHECK(x.flags().row_contiguous); + CHECK(!x.flags().col_contiguous); + out = slice(x, {0, 0}, {10, 5}); + eval(out); + CHECK(!out.flags().row_contiguous); + CHECK(!out.flags().col_contiguous); + out = slice(x, {0, 0}, {5, 10}); + eval(out); + CHECK(out.flags().row_contiguous); + CHECK(!out.flags().col_contiguous); + x = transpose(x, {1, 0}); + eval(x); + CHECK(!x.flags().row_contiguous); + CHECK(x.flags().col_contiguous); + out = slice(x, {0, 0}, {10, 5}); + eval(out); + CHECK(!out.flags().row_contiguous); + CHECK(out.flags().col_contiguous); + out = slice(x, {0, 0}, {5, 10}); + eval(out); + CHECK(!out.flags().row_contiguous); + CHECK(!out.flags().col_contiguous); } TEST_CASE("test slice update") {