Reshape improvement (#818)

This commit is contained in:
Angelos Katharopoulos
2024-03-12 17:54:31 -07:00
committed by GitHub
parent 5ad133f8bb
commit 29d0c10ee5
8 changed files with 199 additions and 84 deletions

View File

@@ -56,6 +56,50 @@ TEST_CASE("test reshape") {
CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument);
y = reshape(x, {1, 5, 0});
CHECK_EQ(y.shape(), std::vector<int>{1, 5, 0});
// Check that reshaping a transposed array doesn't result in a copy
x = reshape(arange(64), {2, 4, 8});
x.eval();
CHECK_EQ(x.strides()[0], 32);
CHECK_EQ(x.strides()[1], 8);
CHECK_EQ(x.strides()[2], 1);
y = reshape(transpose(x, {0, 2, 1}), {2, 4, 2, 4});
y.eval();
CHECK_EQ(y.strides()[0], 32);
CHECK_EQ(y.strides()[1], 2);
CHECK_EQ(y.strides()[2], 1);
CHECK_EQ(y.strides()[3], 8);
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
// Split transposed (2, 8, 4) -> (2, 8, 2, 2)
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 2});
y.eval();
CHECK_EQ(y.strides()[0], 32);
CHECK_EQ(y.strides()[1], 1);
CHECK_EQ(y.strides()[2], 16);
CHECK_EQ(y.strides()[3], 8);
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
// Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2)
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2});
y.eval();
CHECK_EQ(y.strides()[0], 32);
CHECK_EQ(y.strides()[1], 1);
CHECK_EQ(y.strides()[2], 16);
// y.strides()[3] can be anything since y.shape()[3] == 1
CHECK_EQ(y.strides()[4], 8);
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
// Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2, 1)
y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2, 1});
y.eval();
CHECK_EQ(y.strides()[0], 32);
CHECK_EQ(y.strides()[1], 1);
CHECK_EQ(y.strides()[2], 16);
// y.strides()[3] can be anything since y.shape()[3] == 1
CHECK_EQ(y.strides()[4], 8);
// y.strides()[5] can be anything since y.shape()[5] == 1
CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
}
TEST_CASE("test flatten") {