Use int64 stride everywhere (#1671)

* use int64 stride everywhere

* fix ext

* fix ext

* more shape + cleanup

* one more

* few more
This commit is contained in:
Awni Hannun
2024-12-09 11:09:02 -08:00
committed by GitHub
parent 35b412c099
commit 40c62c1321
102 changed files with 1262 additions and 1705 deletions

View File

@@ -129,18 +129,10 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(3.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
@@ -286,9 +278,9 @@ TEST_CASE("test SVD factorization") {
const auto& S = outs[1];
const auto& Vt = outs[2];
CHECK_EQ(U.shape(), std::vector<int>{5, 5});
CHECK_EQ(S.shape(), std::vector<int>{4});
CHECK_EQ(Vt.shape(), std::vector<int>{4, 4});
CHECK_EQ(U.shape(), Shape{5, 5});
CHECK_EQ(S.shape(), Shape{4});
CHECK_EQ(Vt.shape(), Shape{4, 4});
const auto U_slice = slice(U, {0, 0}, {U.shape(0), S.shape(0)});