mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
Use int64 stride everywhere (#1671)
* use int64 stride everywhere * fix ext * fix ext * more shape + cleanup * one more * few more
This commit is contained in:
@@ -129,18 +129,10 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
|
||||
doctest::Approx(3.0));
|
||||
CHECK_EQ(
|
||||
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
|
||||
std::vector<int>{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
|
||||
std::vector<int>{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
|
||||
std::vector<int>{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
|
||||
std::vector<int>{1, 1});
|
||||
CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
|
||||
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||
@@ -286,9 +278,9 @@ TEST_CASE("test SVD factorization") {
|
||||
const auto& S = outs[1];
|
||||
const auto& Vt = outs[2];
|
||||
|
||||
CHECK_EQ(U.shape(), std::vector<int>{5, 5});
|
||||
CHECK_EQ(S.shape(), std::vector<int>{4});
|
||||
CHECK_EQ(Vt.shape(), std::vector<int>{4, 4});
|
||||
CHECK_EQ(U.shape(), Shape{5, 5});
|
||||
CHECK_EQ(S.shape(), Shape{4});
|
||||
CHECK_EQ(Vt.shape(), Shape{4, 4});
|
||||
|
||||
const auto U_slice = slice(U, {0, 0}, {U.shape(0), S.shape(0)});
|
||||
|
||||
|
Reference in New Issue
Block a user