mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-01 16:26:49 +08:00
Fix contiguity check (#1336)
Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
parent
f12f24a77c
commit
9d26441224
@ -115,8 +115,8 @@ inline auto check_contiguity(
|
||||
bool is_col_contiguous = true;
|
||||
|
||||
for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
|
||||
is_row_contiguous &= strides[i] == f_stride || shape[i] == 1;
|
||||
is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
|
||||
is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
|
||||
is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
|
||||
f_stride *= shape[i];
|
||||
b_stride *= shape[ri];
|
||||
if (strides[i] > 0) {
|
||||
|
@ -546,8 +546,8 @@ void fft_op(
|
||||
auto [data_size, is_row_contiguous, is_col_contiguous] =
|
||||
check_contiguity(x.shape(), strides);
|
||||
|
||||
flags.col_contiguous = is_row_contiguous;
|
||||
flags.row_contiguous = is_col_contiguous;
|
||||
flags.col_contiguous = is_col_contiguous;
|
||||
flags.row_contiguous = is_row_contiguous;
|
||||
flags.contiguous = data_size == x_copy.size();
|
||||
|
||||
x_copy.set_data(
|
||||
|
@ -100,6 +100,36 @@ TEST_CASE("test reshape") {
|
||||
CHECK_EQ(y.strides()[4], 8);
|
||||
// y.strides()[5] can be anything since y.shape()[5] == 1
|
||||
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
|
||||
|
||||
// Check contiguity preservation
|
||||
x = ones({10, 10});
|
||||
eval(x);
|
||||
CHECK(x.flags().row_contiguous);
|
||||
CHECK(!x.flags().col_contiguous);
|
||||
y = reshape(x, {2, 5, 10});
|
||||
eval(y);
|
||||
CHECK(y.flags().row_contiguous);
|
||||
CHECK(!y.flags().col_contiguous);
|
||||
y = reshape(x, {10, 1, 10, 1});
|
||||
eval(y);
|
||||
CHECK(y.flags().row_contiguous);
|
||||
CHECK(!y.flags().col_contiguous);
|
||||
x = transpose(x, {1, 0});
|
||||
eval(x);
|
||||
CHECK(!x.flags().row_contiguous);
|
||||
CHECK(x.flags().col_contiguous);
|
||||
y = reshape(x, {2, 5, 10});
|
||||
eval(y);
|
||||
CHECK(!y.flags().row_contiguous);
|
||||
CHECK(y.flags().col_contiguous);
|
||||
y = reshape(x, {2, 50});
|
||||
eval(y);
|
||||
CHECK(y.flags().row_contiguous);
|
||||
CHECK(!y.flags().col_contiguous);
|
||||
y = reshape(x, {10, 1, 10, 1});
|
||||
eval(y);
|
||||
CHECK(!y.flags().row_contiguous);
|
||||
CHECK(y.flags().col_contiguous);
|
||||
}
|
||||
|
||||
TEST_CASE("test flatten") {
|
||||
@ -196,6 +226,32 @@ TEST_CASE("test slice") {
|
||||
|
||||
out = slice(x, {0, 0}, {2, 4}, {1, 2});
|
||||
CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
|
||||
|
||||
// Check contiguity preservation
|
||||
x = ones({10, 10}) * 2;
|
||||
eval(x);
|
||||
CHECK(x.flags().row_contiguous);
|
||||
CHECK(!x.flags().col_contiguous);
|
||||
out = slice(x, {0, 0}, {10, 5});
|
||||
eval(out);
|
||||
CHECK(!out.flags().row_contiguous);
|
||||
CHECK(!out.flags().col_contiguous);
|
||||
out = slice(x, {0, 0}, {5, 10});
|
||||
eval(out);
|
||||
CHECK(out.flags().row_contiguous);
|
||||
CHECK(!out.flags().col_contiguous);
|
||||
x = transpose(x, {1, 0});
|
||||
eval(x);
|
||||
CHECK(!x.flags().row_contiguous);
|
||||
CHECK(x.flags().col_contiguous);
|
||||
out = slice(x, {0, 0}, {10, 5});
|
||||
eval(out);
|
||||
CHECK(!out.flags().row_contiguous);
|
||||
CHECK(out.flags().col_contiguous);
|
||||
out = slice(x, {0, 0}, {5, 10});
|
||||
eval(out);
|
||||
CHECK(!out.flags().row_contiguous);
|
||||
CHECK(!out.flags().col_contiguous);
|
||||
}
|
||||
|
||||
TEST_CASE("test slice update") {
|
||||
|
Loading…
Reference in New Issue
Block a user