Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -342,9 +342,9 @@ TEST_CASE("test vmap gather") {
auto x = zeros({2, 2, 2, 2});
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto out = vmap(fun, {0, -1})({x, y})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 2, 3, 2, 2});
CHECK_EQ(out.shape(), Shape{2, 2, 3, 2, 2});
out = vmap(fun, {0, -1}, {3})({x, y})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2, 2});
}
{
@@ -358,7 +358,7 @@ TEST_CASE("test vmap gather") {
auto x = zeros({2, 2, 2, 2});
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto out = vmap(fun, {0, 0})({x, y})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2});
}
{
@@ -373,7 +373,7 @@ TEST_CASE("test vmap gather") {
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto out = vmap(fun, {-1, 0})({x, y})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2, 2});
}
{
@@ -388,11 +388,11 @@ TEST_CASE("test vmap gather") {
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto z = array({0, 1, 0, 0, 1, 0}, {2, 3});
auto out = vmap(fun, {-1, 0, 0})({x, y, z})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2});
z = array({0, 1, 0, 0, 1, 0}, {3, 2});
out = vmap(fun, {-1, 0, 1})({x, y, z})[0];
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2});
}
}
@@ -483,9 +483,9 @@ TEST_CASE("test vmap SVD") {
const auto& S = out.at(1);
const auto& Vt = out.at(2);
CHECK_EQ(U.shape(), std::vector<int>{a.shape(1), a.shape(0), a.shape(0)});
CHECK_EQ(S.shape(), std::vector<int>{a.shape(1), a.shape(2)});
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(1), a.shape(2), a.shape(2)});
CHECK_EQ(U.shape(), Shape{a.shape(1), a.shape(0), a.shape(0)});
CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)});
CHECK_EQ(Vt.shape(), Shape{a.shape(1), a.shape(2), a.shape(2)});
}
// vmap over the third axis.
@@ -495,8 +495,8 @@ TEST_CASE("test vmap SVD") {
const auto& S = out.at(1);
const auto& Vt = out.at(2);
CHECK_EQ(U.shape(), std::vector<int>{a.shape(2), a.shape(0), a.shape(0)});
CHECK_EQ(S.shape(), std::vector<int>{a.shape(2), a.shape(0)});
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(2), a.shape(1), a.shape(1)});
CHECK_EQ(U.shape(), Shape{a.shape(2), a.shape(0), a.shape(0)});
CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
}
}