Fix contiguity check (#1336)

Co-authored-by: Alex Barron <abarron22@apple.com>
This commit is contained in:
Angelos Katharopoulos
2024-08-19 16:05:06 -07:00
committed by GitHub
parent f12f24a77c
commit 9d26441224
3 changed files with 60 additions and 4 deletions

View File

@@ -100,6 +100,36 @@ TEST_CASE("test reshape") {
CHECK_EQ(y.strides()[4], 8);
// y.strides()[5] can be anything since y.shape()[5] == 1
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
// Check contiguity preservation
x = ones({10, 10});
eval(x);
CHECK(x.flags().row_contiguous);
CHECK(!x.flags().col_contiguous);
y = reshape(x, {2, 5, 10});
eval(y);
CHECK(y.flags().row_contiguous);
CHECK(!y.flags().col_contiguous);
y = reshape(x, {10, 1, 10, 1});
eval(y);
CHECK(y.flags().row_contiguous);
CHECK(!y.flags().col_contiguous);
x = transpose(x, {1, 0});
eval(x);
CHECK(!x.flags().row_contiguous);
CHECK(x.flags().col_contiguous);
y = reshape(x, {2, 5, 10});
eval(y);
CHECK(!y.flags().row_contiguous);
CHECK(y.flags().col_contiguous);
y = reshape(x, {2, 50});
eval(y);
CHECK(y.flags().row_contiguous);
CHECK(!y.flags().col_contiguous);
y = reshape(x, {10, 1, 10, 1});
eval(y);
CHECK(!y.flags().row_contiguous);
CHECK(y.flags().col_contiguous);
}
TEST_CASE("test flatten") {
@@ -196,6 +226,32 @@ TEST_CASE("test slice") {
out = slice(x, {0, 0}, {2, 4}, {1, 2});
CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
// Check contiguity preservation
x = ones({10, 10}) * 2;
eval(x);
CHECK(x.flags().row_contiguous);
CHECK(!x.flags().col_contiguous);
out = slice(x, {0, 0}, {10, 5});
eval(out);
CHECK(!out.flags().row_contiguous);
CHECK(!out.flags().col_contiguous);
out = slice(x, {0, 0}, {5, 10});
eval(out);
CHECK(out.flags().row_contiguous);
CHECK(!out.flags().col_contiguous);
x = transpose(x, {1, 0});
eval(x);
CHECK(!x.flags().row_contiguous);
CHECK(x.flags().col_contiguous);
out = slice(x, {0, 0}, {10, 5});
eval(out);
CHECK(!out.flags().row_contiguous);
CHECK(out.flags().col_contiguous);
out = slice(x, {0, 0}, {5, 10});
eval(out);
CHECK(!out.flags().row_contiguous);
CHECK(!out.flags().col_contiguous);
}
TEST_CASE("test slice update") {