mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Fix slice data size (#1394)
* fix slice data size and add tests * fix contiguous flag * simplify stride and perform copy for non-contiguous arrays * fix cpu * comment
This commit is contained in:
@@ -228,7 +228,7 @@ TEST_CASE("test slice") {
|
||||
CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
|
||||
|
||||
// Check contiguity preservation
|
||||
x = ones({10, 10}) * 2;
|
||||
x = ones({10, 10});
|
||||
eval(x);
|
||||
CHECK(x.flags().row_contiguous);
|
||||
CHECK(!x.flags().col_contiguous);
|
||||
@@ -252,6 +252,59 @@ TEST_CASE("test slice") {
|
||||
eval(out);
|
||||
CHECK(!out.flags().row_contiguous);
|
||||
CHECK(!out.flags().col_contiguous);
|
||||
|
||||
x = ones({6, 4, 10});
|
||||
out = slice(x, {0, 0, 0}, {6, 4, 10}, {2, 1, 2});
|
||||
eval(out);
|
||||
CHECK(!out.flags().contiguous);
|
||||
CHECK(!out.flags().row_contiguous);
|
||||
CHECK(!out.flags().col_contiguous);
|
||||
|
||||
// Check data size correctness
|
||||
x = ones({4});
|
||||
out = slice(x, {0}, {2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 2);
|
||||
|
||||
out = slice(x, {2}, {4});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 2);
|
||||
|
||||
out = slice(x, {0}, {4}, {2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 4);
|
||||
|
||||
x = ones({4, 4});
|
||||
out = slice(x, {0, 0}, {2, 4});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 8);
|
||||
|
||||
out = slice(x, {0, 0}, {1, 2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 2);
|
||||
|
||||
out = slice(x, {0, 1}, {4, 4});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 15);
|
||||
|
||||
out = slice(x, {1, 2}, {3, 4});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 6);
|
||||
|
||||
x = ones({4, 4, 4});
|
||||
out = slice(x, {0, 0, 0}, {4, 2, 2});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 54);
|
||||
|
||||
x = ones({4, 4, 4});
|
||||
out = slice(x, {2, 2, 2}, {3, 3, 3});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 1);
|
||||
|
||||
x = ones({4, 4, 4});
|
||||
out = slice(x, {2, 2, 2}, {3, 4, 3});
|
||||
eval(out);
|
||||
CHECK_EQ(out.data_size(), 5);
|
||||
}
|
||||
|
||||
TEST_CASE("test slice update") {
|
||||
|
Reference in New Issue
Block a user