Adds nuclear norm support (#1894)

* adjust norm unit test tolerance
This commit is contained in:
Abe Leininger
2025-03-04 15:26:02 -06:00
committed by GitHub
parent 9680f72cca
commit 3835a428c5
11 changed files with 260 additions and 55 deletions

View File

@@ -100,7 +100,7 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
doctest::Approx(expected));
x = reshape(arange(9), {3, 3});
x = reshape(arange(9, float32), {3, 3});
CHECK(allclose(
norm(x, 2.0, 0, false),
@@ -129,10 +129,34 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(3.0));
CHECK_EQ(
norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
doctest::Approx(14.226707));
CHECK_EQ(
norm(x, 2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
doctest::Approx(14.226707));
CHECK_EQ(
norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
doctest::Approx(0.0));
CHECK_EQ(
norm(x, -2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
doctest::Approx(0.0));
CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
CHECK_EQ(
norm(x, 2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
Shape{1, 1});
CHECK_EQ(
norm(x, 2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
Shape{1, 1});
CHECK_EQ(
norm(x, -2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
Shape{1, 1});
CHECK_EQ(
norm(x, -2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
Shape{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
@@ -140,8 +164,14 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
CHECK_EQ(
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
doctest::Approx(15.0));
CHECK_EQ(
norm(x, -2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
doctest::Approx(0.0));
CHECK_EQ(
norm(x, 2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
doctest::Approx(14.226707));
x = reshape(arange(18), {2, 3, 3});
x = reshape(arange(18, float32), {2, 3, 3});
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
CHECK(allclose(
norm(x, 3.0, 0),
@@ -199,13 +229,31 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
.item<bool>());
CHECK(allclose(
norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu),
array({22.045408, 24.155825, 26.318918}))
.item<bool>());
CHECK(allclose(
norm(x, 2.0, std::vector<int>{1, 2}, false, Device::cpu),
array({14.226707, 39.759212}))
.item<bool>());
CHECK(allclose(
norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu),
array({3, 2.7378995, 2.5128777}))
.item<bool>());
CHECK(allclose(
norm(x, -2.0, std::vector<int>{1, 2}, false, Device::cpu),
array({4.979028e-16, 7.009628e-16}),
/* rtol = */ 1e-5,
/* atol = */ 1e-6)
.item<bool>());
}
TEST_CASE("[mlx.core.linalg.norm] string ord") {
array x({1, 2, 3});
CHECK_THROWS(norm(x, "fro"));
x = reshape(arange(9), {3, 3});
x = reshape(arange(9, float32), {3, 3});
CHECK_THROWS(norm(x, "bad ord"));
CHECK_EQ(
@@ -214,8 +262,11 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
CHECK_EQ(
norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
doctest::Approx(14.2828568570857));
CHECK_EQ(
norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
doctest::Approx(15.491934));
x = reshape(arange(18), {2, 3, 3});
x = reshape(arange(18, float32), {2, 3, 3});
CHECK(allclose(
norm(x, "fro", std::vector<int>{0, 1}),
array({22.24859546, 24.31049156, 26.43860813}))
@@ -240,6 +291,18 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
norm(x, "f", std::vector<int>{2, 1}),
array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(
norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu),
array({25.045408, 26.893724, 28.831797}))
.item<bool>());
CHECK(allclose(
norm(x, "nuc", std::vector<int>{1, 2}, false, Device::cpu),
array({15.491934, 40.211937}))
.item<bool>());
CHECK(allclose(
norm(x, "nuc", std::vector<int>{-2, -1}, false, Device::cpu),
array({15.491934, 40.211937}))
.item<bool>());
}
TEST_CASE("test QR factorization") {
@@ -271,7 +334,7 @@ TEST_CASE("test SVD factorization") {
const auto prng_key = random::key(42);
const auto A = mlx::core::random::normal({5, 4}, prng_key);
const auto outs = linalg::svd(A, Device::cpu);
const auto outs = linalg::svd(A, true, Device::cpu);
CHECK_EQ(outs.size(), 3);
const auto& U = outs[0];
@@ -291,6 +354,15 @@ TEST_CASE("test SVD factorization") {
CHECK_EQ(U.dtype(), float32);
CHECK_EQ(S.dtype(), float32);
CHECK_EQ(Vt.dtype(), float32);
// Test singular values
const auto& outs_sv = linalg::svd(A, false, Device::cpu);
const auto SV = outs_sv[0];
CHECK_EQ(SV.shape(), Shape{4});
CHECK_EQ(SV.dtype(), float32);
CHECK(allclose(norm(SV), norm(A, "fro")).item<bool>());
}
TEST_CASE("test matrix inversion") {