Fix slice data size (#1394)

* fix slice data size and add tests

* fix contiguous flag

* simplify stride and perform copy for non-contiguous arrays

* fix cpu

* comment
This commit is contained in:
Awni Hannun
2024-09-04 19:10:43 -07:00
committed by GitHub
parent 11371fe251
commit 7cca1727af
12 changed files with 129 additions and 39 deletions

View File

@@ -228,7 +228,7 @@ TEST_CASE("test slice") {
CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
// Check contiguity preservation
x = ones({10, 10}) * 2;
x = ones({10, 10});
eval(x);
CHECK(x.flags().row_contiguous);
CHECK(!x.flags().col_contiguous);
@@ -252,6 +252,59 @@ TEST_CASE("test slice") {
eval(out);
CHECK(!out.flags().row_contiguous);
CHECK(!out.flags().col_contiguous);
x = ones({6, 4, 10});
out = slice(x, {0, 0, 0}, {6, 4, 10}, {2, 1, 2});
eval(out);
CHECK(!out.flags().contiguous);
CHECK(!out.flags().row_contiguous);
CHECK(!out.flags().col_contiguous);
// Check data size correctness
x = ones({4});
out = slice(x, {0}, {2});
eval(out);
CHECK_EQ(out.data_size(), 2);
out = slice(x, {2}, {4});
eval(out);
CHECK_EQ(out.data_size(), 2);
out = slice(x, {0}, {4}, {2});
eval(out);
CHECK_EQ(out.data_size(), 4);
x = ones({4, 4});
out = slice(x, {0, 0}, {2, 4});
eval(out);
CHECK_EQ(out.data_size(), 8);
out = slice(x, {0, 0}, {1, 2});
eval(out);
CHECK_EQ(out.data_size(), 2);
out = slice(x, {0, 1}, {4, 4});
eval(out);
CHECK_EQ(out.data_size(), 15);
out = slice(x, {1, 2}, {3, 4});
eval(out);
CHECK_EQ(out.data_size(), 6);
x = ones({4, 4, 4});
out = slice(x, {0, 0, 0}, {4, 2, 2});
eval(out);
CHECK_EQ(out.data_size(), 54);
x = ones({4, 4, 4});
out = slice(x, {2, 2, 2}, {3, 3, 3});
eval(out);
CHECK_EQ(out.data_size(), 1);
x = ones({4, 4, 4});
out = slice(x, {2, 2, 2}, {3, 4, 3});
eval(out);
CHECK_EQ(out.data_size(), 5);
}
TEST_CASE("test slice update") {