mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-28 22:28:11 +08:00
@@ -100,7 +100,7 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
|
||||
doctest::Approx(expected));
|
||||
|
||||
x = reshape(arange(9), {3, 3});
|
||||
x = reshape(arange(9, float32), {3, 3});
|
||||
|
||||
CHECK(allclose(
|
||||
norm(x, 2.0, 0, false),
|
||||
@@ -129,10 +129,34 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
|
||||
doctest::Approx(3.0));
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(14.226707));
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(14.226707));
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(0.0));
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(0.0));
|
||||
CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
|
||||
Shape{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
|
||||
Shape{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
|
||||
Shape{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
|
||||
Shape{1, 1});
|
||||
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||
@@ -140,8 +164,14 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
CHECK_EQ(
|
||||
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||
doctest::Approx(15.0));
|
||||
CHECK_EQ(
|
||||
norm(x, -2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(0.0));
|
||||
CHECK_EQ(
|
||||
norm(x, 2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(14.226707));
|
||||
|
||||
x = reshape(arange(18), {2, 3, 3});
|
||||
x = reshape(arange(18, float32), {2, 3, 3});
|
||||
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
|
||||
CHECK(allclose(
|
||||
norm(x, 3.0, 0),
|
||||
@@ -199,13 +229,31 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu),
|
||||
array({22.045408, 24.155825, 26.318918}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, 2.0, std::vector<int>{1, 2}, false, Device::cpu),
|
||||
array({14.226707, 39.759212}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu),
|
||||
array({3, 2.7378995, 2.5128777}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, -2.0, std::vector<int>{1, 2}, false, Device::cpu),
|
||||
array({4.979028e-16, 7.009628e-16}),
|
||||
/* rtol = */ 1e-5,
|
||||
/* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||
array x({1, 2, 3});
|
||||
CHECK_THROWS(norm(x, "fro"));
|
||||
|
||||
x = reshape(arange(9), {3, 3});
|
||||
x = reshape(arange(9, float32), {3, 3});
|
||||
CHECK_THROWS(norm(x, "bad ord"));
|
||||
|
||||
CHECK_EQ(
|
||||
@@ -214,8 +262,11 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||
CHECK_EQ(
|
||||
norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
|
||||
doctest::Approx(14.2828568570857));
|
||||
CHECK_EQ(
|
||||
norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
|
||||
doctest::Approx(15.491934));
|
||||
|
||||
x = reshape(arange(18), {2, 3, 3});
|
||||
x = reshape(arange(18, float32), {2, 3, 3});
|
||||
CHECK(allclose(
|
||||
norm(x, "fro", std::vector<int>{0, 1}),
|
||||
array({22.24859546, 24.31049156, 26.43860813}))
|
||||
@@ -240,6 +291,18 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||
norm(x, "f", std::vector<int>{2, 1}),
|
||||
array({14.28285686, 39.7617907}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu),
|
||||
array({25.045408, 26.893724, 28.831797}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, "nuc", std::vector<int>{1, 2}, false, Device::cpu),
|
||||
array({15.491934, 40.211937}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(x, "nuc", std::vector<int>{-2, -1}, false, Device::cpu),
|
||||
array({15.491934, 40.211937}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test QR factorization") {
|
||||
@@ -271,7 +334,7 @@ TEST_CASE("test SVD factorization") {
|
||||
|
||||
const auto prng_key = random::key(42);
|
||||
const auto A = mlx::core::random::normal({5, 4}, prng_key);
|
||||
const auto outs = linalg::svd(A, Device::cpu);
|
||||
const auto outs = linalg::svd(A, true, Device::cpu);
|
||||
CHECK_EQ(outs.size(), 3);
|
||||
|
||||
const auto& U = outs[0];
|
||||
@@ -291,6 +354,15 @@ TEST_CASE("test SVD factorization") {
|
||||
CHECK_EQ(U.dtype(), float32);
|
||||
CHECK_EQ(S.dtype(), float32);
|
||||
CHECK_EQ(Vt.dtype(), float32);
|
||||
|
||||
// Test singular values
|
||||
const auto& outs_sv = linalg::svd(A, false, Device::cpu);
|
||||
const auto SV = outs_sv[0];
|
||||
|
||||
CHECK_EQ(SV.shape(), Shape{4});
|
||||
CHECK_EQ(SV.dtype(), float32);
|
||||
|
||||
CHECK(allclose(norm(SV), norm(A, "fro")).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test matrix inversion") {
|
||||
|
||||
@@ -466,15 +466,19 @@ TEST_CASE("test vmap scatter") {
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap SVD") {
|
||||
auto fun = [](std::vector<array> inputs) {
|
||||
return linalg::svd(inputs.at(0), Device::cpu);
|
||||
auto svd_full = [](std::vector<array> inputs) {
|
||||
return linalg::svd(inputs.at(0), true, Device::cpu);
|
||||
};
|
||||
|
||||
auto svd_singular = [](std::vector<array> inputs) {
|
||||
return linalg::svd(inputs.at(0), false, Device::cpu);
|
||||
};
|
||||
|
||||
auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
|
||||
|
||||
// vmap over the second axis.
|
||||
{
|
||||
auto out = vmap(fun, /* in_axes = */ {1})({a});
|
||||
auto out = vmap(svd_full, /* in_axes = */ {1})({a});
|
||||
const auto& U = out.at(0);
|
||||
const auto& S = out.at(1);
|
||||
const auto& Vt = out.at(2);
|
||||
@@ -486,7 +490,7 @@ TEST_CASE("test vmap SVD") {
|
||||
|
||||
// vmap over the third axis.
|
||||
{
|
||||
auto out = vmap(fun, /* in_axes = */ {2})({a});
|
||||
auto out = vmap(svd_full, /* in_axes = */ {2})({a});
|
||||
const auto& U = out.at(0);
|
||||
const auto& S = out.at(1);
|
||||
const auto& Vt = out.at(2);
|
||||
@@ -495,6 +499,21 @@ TEST_CASE("test vmap SVD") {
|
||||
CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
|
||||
CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
|
||||
}
|
||||
|
||||
// test singular values
|
||||
{
|
||||
auto out = vmap(svd_singular, /* in_axes = */ {1})({a});
|
||||
const auto& S = out.at(0);
|
||||
|
||||
CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)});
|
||||
}
|
||||
|
||||
{
|
||||
auto out = vmap(svd_singular, /* in_axes = */ {2})({a});
|
||||
const auto& S = out.at(0);
|
||||
|
||||
CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test vmap dynamic slices") {
|
||||
|
||||
Reference in New Issue
Block a user