mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Reshape improvement (#818)
This commit is contained in:

committed by
GitHub

parent
5ad133f8bb
commit
29d0c10ee5
@@ -56,6 +56,50 @@ TEST_CASE("test reshape") {
|
||||
CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument);
|
||||
y = reshape(x, {1, 5, 0});
|
||||
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
|
||||
|
||||
// Check that reshaping a transposed array doesn't result in a copy
|
||||
x = reshape(arange(64), {2, 4, 8});
|
||||
x.eval();
|
||||
CHECK_EQ(x.strides()[0], 32);
|
||||
CHECK_EQ(x.strides()[1], 8);
|
||||
CHECK_EQ(x.strides()[2], 1);
|
||||
y = reshape(transpose(x, {0, 2, 1}), {2, 4, 2, 4});
|
||||
y.eval();
|
||||
CHECK_EQ(y.strides()[0], 32);
|
||||
CHECK_EQ(y.strides()[1], 2);
|
||||
CHECK_EQ(y.strides()[2], 1);
|
||||
CHECK_EQ(y.strides()[3], 8);
|
||||
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
|
||||
|
||||
// Split transposed (2, 8, 4) -> (2, 8, 2, 2)
|
||||
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 2});
|
||||
y.eval();
|
||||
CHECK_EQ(y.strides()[0], 32);
|
||||
CHECK_EQ(y.strides()[1], 1);
|
||||
CHECK_EQ(y.strides()[2], 16);
|
||||
CHECK_EQ(y.strides()[3], 8);
|
||||
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
|
||||
|
||||
// Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2)
|
||||
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2});
|
||||
y.eval();
|
||||
CHECK_EQ(y.strides()[0], 32);
|
||||
CHECK_EQ(y.strides()[1], 1);
|
||||
CHECK_EQ(y.strides()[2], 16);
|
||||
// y.strides()[3] can be anything since y.shape()[3] == 1
|
||||
CHECK_EQ(y.strides()[4], 8);
|
||||
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
|
||||
|
||||
// Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2, 1)
|
||||
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2, 1});
|
||||
y.eval();
|
||||
CHECK_EQ(y.strides()[0], 32);
|
||||
CHECK_EQ(y.strides()[1], 1);
|
||||
CHECK_EQ(y.strides()[2], 16);
|
||||
// y.strides()[3] can be anything since y.shape()[3] == 1
|
||||
CHECK_EQ(y.strides()[4], 8);
|
||||
// y.strides()[5] can be anything since y.shape()[5] == 1
|
||||
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
|
||||
}
|
||||
|
||||
TEST_CASE("test flatten") {
|
||||
|
Reference in New Issue
Block a user