This commit is contained in:
Awni Hannun
2024-12-06 16:50:08 -08:00
parent ee59d50293
commit 2b9c24c517
4 changed files with 130 additions and 30 deletions

View File

@@ -4,6 +4,7 @@
#define _USE_MATH_DEFINES
#include <cmath>
#include <iostream> // TODO
#include <numeric>
#include "doctest/doctest.h"
@@ -3777,6 +3778,10 @@ TEST_CASE("test dynamic reshape") {
// Bad character
CHECK_THROWS(dynamic_reshape(x, {"&", 1, 1}));
// Malformed
CHECK_THROWS(dynamic_reshape(x, {"+a", 1, 1}));
CHECK_THROWS(dynamic_reshape(x, {"a+", 1, 1}));
// No dim in string
CHECK_THROWS(dynamic_reshape(x, {"1", 1, 1}));
@@ -3785,7 +3790,41 @@ TEST_CASE("test dynamic reshape") {
// Too many dims
CHECK_THROWS(dynamic_reshape(x, {"a", "b", "c", "d"}));
CHECK_THROWS(dynamic_reshape(x, {"abcd", 1, 1}));
// Too many inferred dims
CHECK_THROWS(dynamic_reshape(x, {"a", -1, -1}));
// Bad sizes
x = zeros({2, 2, 2});
CHECK_THROWS_AS(dynamic_reshape(x, {7}), std::invalid_argument);
CHECK_THROWS_AS(dynamic_reshape(x, {-1, 7}), std::invalid_argument);
// Works with empty array
x = array({});
auto y = dynamic_reshape(x, {0, 0, 0});
CHECK_EQ(y.shape(), Shape{0, 0, 0});
CHECK_THROWS_AS(dynamic_reshape(x, {}), std::invalid_argument);
CHECK_THROWS_AS(dynamic_reshape(x, {1}), std::invalid_argument);
y = dynamic_reshape(x, {1, 5, 0});
CHECK_EQ(y.shape(), Shape{1, 5, 0});
x = array({1, 2, 3});
y = dynamic_reshape(x, {"a", 1, 1});
CHECK_EQ(y.shape(), Shape{3, 1, 1});
x = zeros({2, 2});
y = dynamic_reshape(x, {"a*b"});
CHECK_EQ(y.shape(), Shape{4});
y = dynamic_reshape(x, {"2*a"});
CHECK_EQ(y.shape(), Shape{4});
x = zeros({2, 20});
y = dynamic_reshape(x, {"a*20"});
CHECK_EQ(y.shape(), Shape{40});
x = zeros({2, 20});
y = dynamic_reshape(x, {"a", "b/10", 10});
CHECK_EQ(y.shape(), Shape{2, 2, 10});
}