mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix pinv (#2110)
This commit is contained in:
parent
38c1e720c2
commit
fbc89e3ced
@ -379,7 +379,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
// Prepare S
|
||||
S = expand_dims(S, -2, s);
|
||||
|
||||
return matmul(divide(V, S, s), U);
|
||||
auto rcond = 10. * std::max(m, n) * finfo(a.dtype()).eps;
|
||||
auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s);
|
||||
auto rS =
|
||||
where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s);
|
||||
|
||||
return matmul(multiply(V, rS, s), U, s);
|
||||
}
|
||||
|
||||
array cholesky_inv(
|
||||
|
@ -33,6 +33,9 @@ struct numeric_limits<float16_t> {
|
||||
static constexpr float16_t max() {
|
||||
return bits_to_half(0x7BFF);
|
||||
}
|
||||
static constexpr float16_t epsilon() {
|
||||
return bits_to_half(0x1400);
|
||||
}
|
||||
static constexpr float16_t infinity() {
|
||||
return bits_to_half(0x7C00);
|
||||
}
|
||||
@ -56,6 +59,9 @@ struct numeric_limits<bfloat16_t> {
|
||||
static constexpr bfloat16_t max() {
|
||||
return bits_to_bfloat(0x7F7F);
|
||||
}
|
||||
static constexpr bfloat16_t epsilon() {
|
||||
return bits_to_bfloat(0x3C00);
|
||||
}
|
||||
static constexpr bfloat16_t infinity() {
|
||||
return bits_to_bfloat(0x7F80);
|
||||
}
|
||||
|
@ -283,9 +283,10 @@ int get_var(const char* name, int default_value) {
|
||||
} // namespace env
|
||||
|
||||
template <typename T>
|
||||
void set_finfo_limits(double& min, double& max) {
|
||||
void set_finfo_limits(double& min, double& max, double& eps) {
|
||||
min = numeric_limits<T>::lowest();
|
||||
max = numeric_limits<T>::max();
|
||||
eps = numeric_limits<T>::epsilon();
|
||||
}
|
||||
|
||||
finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||
@ -295,16 +296,16 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (dtype == float32) {
|
||||
set_finfo_limits<float>(min, max);
|
||||
set_finfo_limits<float>(min, max, eps);
|
||||
} else if (dtype == float16) {
|
||||
set_finfo_limits<float16_t>(min, max);
|
||||
set_finfo_limits<float16_t>(min, max, eps);
|
||||
} else if (dtype == bfloat16) {
|
||||
set_finfo_limits<bfloat16_t>(min, max);
|
||||
set_finfo_limits<bfloat16_t>(min, max, eps);
|
||||
} else if (dtype == float64) {
|
||||
set_finfo_limits<double>(min, max);
|
||||
set_finfo_limits<double>(min, max, eps);
|
||||
} else if (dtype == complex64) {
|
||||
this->dtype = float32;
|
||||
set_finfo_limits<float>(min, max);
|
||||
set_finfo_limits<float>(min, max, eps);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,6 +65,7 @@ struct finfo {
|
||||
Dtype dtype;
|
||||
double min;
|
||||
double max;
|
||||
double eps;
|
||||
};
|
||||
|
||||
/** Holds information about integral types. */
|
||||
|
@ -197,6 +197,13 @@ void init_array(nb::module_& m) {
|
||||
"max",
|
||||
&mx::finfo::max,
|
||||
R"pbdoc(The largest representable number.)pbdoc")
|
||||
.def_ro(
|
||||
"eps",
|
||||
&mx::finfo::eps,
|
||||
R"pbdoc(
|
||||
The difference between 1.0 and the next smallest
|
||||
representable number larger than 1.0.
|
||||
)pbdoc")
|
||||
.def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
|
||||
.def("__repr__", [](const mx::finfo& f) {
|
||||
std::ostringstream os;
|
||||
|
@ -103,10 +103,12 @@ class TestDtypes(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min)
|
||||
self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max)
|
||||
self.assertEqual(mx.finfo(mx.float32).eps, np.finfo(np.float32).eps)
|
||||
self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32)
|
||||
|
||||
self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min)
|
||||
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
|
||||
self.assertEqual(mx.finfo(mx.float16).eps, np.finfo(np.float16).eps)
|
||||
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
|
||||
|
||||
def test_iinfo(self):
|
||||
|
@ -232,6 +232,11 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
for M, M_plus in zip(AB, pinvs):
|
||||
self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3))
|
||||
|
||||
# Test singular matrix
|
||||
A = mx.array([[4.0, 1.0], [4.0, 1.0]])
|
||||
A_plus = mx.linalg.pinv(A, stream=mx.cpu)
|
||||
self.assertTrue(mx.allclose(A @ A_plus @ A, A))
|
||||
|
||||
def test_cholesky_inv(self):
|
||||
mx.random.seed(7)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user