mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:18:12 +08:00
@@ -100,7 +100,7 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
|
||||
doctest::Approx(expected));
|
||||
|
||||
x = reshape(arange(9), {3, 3});
|
||||
x = reshape(arange(9, float32), {3, 3});
|
||||
|
||||
CHECK(allclose(
|
||||
norm(x, 2.0, 0, false),
|
||||
@@ -129,10 +129,34 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
|
||||
doctest::Approx(3.0));
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(14.226707));
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(14.226707));
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(0.0));
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(0.0));
|
||||
CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
|
||||
Shape{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
|
||||
Shape{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
|
||||
Shape{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
|
||||
Shape{1, 1});
|
||||
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||
@@ -140,8 +164,14 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
CHECK_EQ(
|
||||
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||
doctest::Approx(15.0));
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(0.0));
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(14.226707));
|
||||
|
||||
x = reshape(arange(18), {2, 3, 3});
|
||||
x = reshape(arange(18, float32), {2, 3, 3});
|
||||
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
|
||||
CHECK(allclose(
|
||||
norm(x, 3.0, 0),
|
||||
@@ -199,13 +229,31 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu),
|
||||
array({22.045408, 24.155825, 26.318918}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, 2.0, std::vector<int>{1, 2}, false, Device::cpu),
|
||||
array({14.226707, 39.759212}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu),
|
||||
array({3, 2.7378995, 2.5128777}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, -2.0, std::vector<int>{1, 2}, false, Device::cpu),
|
||||
array({4.979028e-16, 7.009628e-16}),
|
||||
/* rtol = */ 1e-5,
|
||||
/* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||
array x({1, 2, 3});
|
||||
CHECK_THROWS(norm(x, "fro"));
|
||||
|
||||
x = reshape(arange(9), {3, 3});
|
||||
x = reshape(arange(9, float32), {3, 3});
|
||||
CHECK_THROWS(norm(x, "bad ord"));
|
||||
|
||||
CHECK_EQ(
|
||||
@@ -214,8 +262,11 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||
CHECK_EQ(
|
||||
norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
|
||||
doctest::Approx(14.2828568570857));
|
||||
CHECK_EQ(
|
||||
norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(15.491934));
|
||||
|
||||
x = reshape(arange(18), {2, 3, 3});
|
||||
x = reshape(arange(18, float32), {2, 3, 3});
|
||||
CHECK(allclose(
|
||||
norm(x, "fro", std::vector<int>{0, 1}),
|
||||
array({22.24859546, 24.31049156, 26.43860813}))
|
||||
@@ -240,6 +291,18 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||
norm(x, "f", std::vector<int>{2, 1}),
|
||||
array({14.28285686, 39.7617907}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu),
|
||||
array({25.045408, 26.893724, 28.831797}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, "nuc", std::vector<int>{1, 2}, false, Device::cpu),
|
||||
array({15.491934, 40.211937}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, "nuc", std::vector<int>{-2, -1}, false, Device::cpu),
|
||||
array({15.491934, 40.211937}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test QR factorization") {
|
||||
@@ -271,7 +334,7 @@ TEST_CASE("test SVD factorization") {
|
||||
|
||||
const auto prng_key = random::key(42);
|
||||
const auto A = mlx::core::random::normal({5, 4}, prng_key);
|
||||
const auto outs = linalg::svd(A, Device::cpu);
|
||||
const auto outs = linalg::svd(A, true, Device::cpu);
|
||||
CHECK_EQ(outs.size(), 3);
|
||||
|
||||
const auto& U = outs[0];
|
||||
@@ -291,6 +354,15 @@ TEST_CASE("test SVD factorization") {
|
||||
CHECK_EQ(U.dtype(), float32);
|
||||
CHECK_EQ(S.dtype(), float32);
|
||||
CHECK_EQ(Vt.dtype(), float32);
|
||||
|
||||
// Test singular values
|
||||
const auto& outs_sv = linalg::svd(A, false, Device::cpu);
|
||||
const auto SV = outs_sv[0];
|
||||
|
||||
CHECK_EQ(SV.shape(), Shape{4});
|
||||
CHECK_EQ(SV.dtype(), float32);
|
||||
|
||||
CHECK(allclose(norm(SV), norm(A, "fro")).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test matrix inversion") {
|
||||
|
Reference in New Issue
Block a user