mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -342,9 +342,9 @@ TEST_CASE("test vmap gather") {
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto out = vmap(fun, {0, -1})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 2, 3, 2, 2});
|
||||
CHECK_EQ(out.shape(), Shape{2, 2, 3, 2, 2});
|
||||
out = vmap(fun, {0, -1}, {3})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
|
||||
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2, 2});
|
||||
}
|
||||
|
||||
{
|
||||
@@ -358,7 +358,7 @@ TEST_CASE("test vmap gather") {
|
||||
auto x = zeros({2, 2, 2, 2});
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto out = vmap(fun, {0, 0})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2});
|
||||
}
|
||||
|
||||
{
|
||||
@@ -373,7 +373,7 @@ TEST_CASE("test vmap gather") {
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
|
||||
auto out = vmap(fun, {-1, 0})({x, y})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2, 2});
|
||||
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2, 2});
|
||||
}
|
||||
|
||||
{
|
||||
@@ -388,11 +388,11 @@ TEST_CASE("test vmap gather") {
|
||||
auto y = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto z = array({0, 1, 0, 0, 1, 0}, {2, 3});
|
||||
auto out = vmap(fun, {-1, 0, 0})({x, y, z})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2});
|
||||
|
||||
z = array({0, 1, 0, 0, 1, 0}, {3, 2});
|
||||
out = vmap(fun, {-1, 0, 1})({x, y, z})[0];
|
||||
CHECK_EQ(out.shape(), std::vector<int>{2, 3, 2, 2});
|
||||
CHECK_EQ(out.shape(), Shape{2, 3, 2, 2});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -483,9 +483,9 @@ TEST_CASE("test vmap SVD") {
|
||||
const auto& S = out.at(1);
|
||||
const auto& Vt = out.at(2);
|
||||
|
||||
CHECK_EQ(U.shape(), std::vector<int>{a.shape(1), a.shape(0), a.shape(0)});
|
||||
CHECK_EQ(S.shape(), std::vector<int>{a.shape(1), a.shape(2)});
|
||||
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(1), a.shape(2), a.shape(2)});
|
||||
CHECK_EQ(U.shape(), Shape{a.shape(1), a.shape(0), a.shape(0)});
|
||||
CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)});
|
||||
CHECK_EQ(Vt.shape(), Shape{a.shape(1), a.shape(2), a.shape(2)});
|
||||
}
|
||||
|
||||
// vmap over the third axis.
|
||||
@@ -495,8 +495,8 @@ TEST_CASE("test vmap SVD") {
|
||||
const auto& S = out.at(1);
|
||||
const auto& Vt = out.at(2);
|
||||
|
||||
CHECK_EQ(U.shape(), std::vector<int>{a.shape(2), a.shape(0), a.shape(0)});
|
||||
CHECK_EQ(S.shape(), std::vector<int>{a.shape(2), a.shape(0)});
|
||||
CHECK_EQ(Vt.shape(), std::vector<int>{a.shape(2), a.shape(1), a.shape(1)});
|
||||
CHECK_EQ(U.shape(), Shape{a.shape(2), a.shape(0), a.shape(0)});
|
||||
CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
|
||||
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user