From fbc89e3ced24f8a8bf0324bf691ce53da9243868 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 23 Apr 2025 13:08:28 -0700 Subject: [PATCH] fix pinv (#2110) --- mlx/linalg.cpp | 7 ++++++- mlx/types/limits.h | 6 ++++++ mlx/utils.cpp | 13 +++++++------ mlx/utils.h | 1 + python/src/array.cpp | 7 +++++++ python/tests/test_array.py | 2 ++ python/tests/test_linalg.py | 5 +++++ 7 files changed, 34 insertions(+), 7 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 5b9b51ad3..53f13486a 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -379,7 +379,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) { // Prepare S S = expand_dims(S, -2, s); - return matmul(divide(V, S, s), U); + auto rcond = 10. * std::max(m, n) * finfo(a.dtype()).eps; + auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s); + auto rS = + where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s); + + return matmul(multiply(V, rS, s), U, s); } array cholesky_inv( diff --git a/mlx/types/limits.h b/mlx/types/limits.h index 7e0de15bc..5f2b1e9e0 100644 --- a/mlx/types/limits.h +++ b/mlx/types/limits.h @@ -33,6 +33,9 @@ struct numeric_limits { static constexpr float16_t max() { return bits_to_half(0x7BFF); } + static constexpr float16_t epsilon() { + return bits_to_half(0x1400); + } static constexpr float16_t infinity() { return bits_to_half(0x7C00); } @@ -56,6 +59,9 @@ struct numeric_limits { static constexpr bfloat16_t max() { return bits_to_bfloat(0x7F7F); } + static constexpr bfloat16_t epsilon() { + return bits_to_bfloat(0x3C00); + } static constexpr bfloat16_t infinity() { return bits_to_bfloat(0x7F80); } diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 188584174..0b2e66352 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -283,9 +283,10 @@ int get_var(const char* name, int default_value) { } // namespace env template -void set_finfo_limits(double& min, double& max) { +void set_finfo_limits(double& min, double& max, double& eps) { min = numeric_limits::lowest(); max = numeric_limits::max(); + eps = numeric_limits::epsilon(); } finfo::finfo(Dtype dtype) : dtype(dtype) { @@ -295,16 +296,16 @@ finfo::finfo(Dtype dtype) : dtype(dtype) { throw std::invalid_argument(msg.str()); } if (dtype == float32) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == float16) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == bfloat16) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == float64) { - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } else if (dtype == complex64) { this->dtype = float32; - set_finfo_limits(min, max); + set_finfo_limits(min, max, eps); } } diff --git a/mlx/utils.h b/mlx/utils.h index 19241e4c6..f0aa7c2de 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -65,6 +65,7 @@ struct finfo { Dtype dtype; double min; double max; + double eps; }; /** Holds information about integral types. */ diff --git a/python/src/array.cpp b/python/src/array.cpp index 467bd0fa5..5f8dbe021 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -197,6 +197,13 @@ void init_array(nb::module_& m) { "max", &mx::finfo::max, R"pbdoc(The largest representable number.)pbdoc") + .def_ro( + "eps", + &mx::finfo::eps, + R"pbdoc( + The difference between 1.0 and the next smallest + representable number larger than 1.0. + )pbdoc") .def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc") .def("__repr__", [](const mx::finfo& f) { std::ostringstream os; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index fa5784ea9..792e666d6 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -103,10 +103,12 @@ class TestDtypes(mlx_tests.MLXTestCase): self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min) self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max) + self.assertEqual(mx.finfo(mx.float32).eps, np.finfo(np.float32).eps) self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32) self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min) self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max) + self.assertEqual(mx.finfo(mx.float16).eps, np.finfo(np.float16).eps) self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16) def test_iinfo(self): diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index ffa355c10..a9fe572af 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -232,6 +232,11 @@ class TestLinalg(mlx_tests.MLXTestCase): for M, M_plus in zip(AB, pinvs): self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3)) + # Test singular matrix + A = mx.array([[4.0, 1.0], [4.0, 1.0]]) + A_plus = mx.linalg.pinv(A, stream=mx.cpu) + self.assertTrue(mx.allclose(A @ A_plus @ A, A)) + def test_cholesky_inv(self): mx.random.seed(7)