mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix pinv (#2110)
This commit is contained in:
parent
38c1e720c2
commit
fbc89e3ced
@ -379,7 +379,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
// Prepare S
|
// Prepare S
|
||||||
S = expand_dims(S, -2, s);
|
S = expand_dims(S, -2, s);
|
||||||
|
|
||||||
return matmul(divide(V, S, s), U);
|
auto rcond = 10. * std::max(m, n) * finfo(a.dtype()).eps;
|
||||||
|
auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s);
|
||||||
|
auto rS =
|
||||||
|
where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s);
|
||||||
|
|
||||||
|
return matmul(multiply(V, rS, s), U, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array cholesky_inv(
|
array cholesky_inv(
|
||||||
|
@ -33,6 +33,9 @@ struct numeric_limits<float16_t> {
|
|||||||
static constexpr float16_t max() {
|
static constexpr float16_t max() {
|
||||||
return bits_to_half(0x7BFF);
|
return bits_to_half(0x7BFF);
|
||||||
}
|
}
|
||||||
|
static constexpr float16_t epsilon() {
|
||||||
|
return bits_to_half(0x1400);
|
||||||
|
}
|
||||||
static constexpr float16_t infinity() {
|
static constexpr float16_t infinity() {
|
||||||
return bits_to_half(0x7C00);
|
return bits_to_half(0x7C00);
|
||||||
}
|
}
|
||||||
@ -56,6 +59,9 @@ struct numeric_limits<bfloat16_t> {
|
|||||||
static constexpr bfloat16_t max() {
|
static constexpr bfloat16_t max() {
|
||||||
return bits_to_bfloat(0x7F7F);
|
return bits_to_bfloat(0x7F7F);
|
||||||
}
|
}
|
||||||
|
static constexpr bfloat16_t epsilon() {
|
||||||
|
return bits_to_bfloat(0x3C00);
|
||||||
|
}
|
||||||
static constexpr bfloat16_t infinity() {
|
static constexpr bfloat16_t infinity() {
|
||||||
return bits_to_bfloat(0x7F80);
|
return bits_to_bfloat(0x7F80);
|
||||||
}
|
}
|
||||||
|
@ -283,9 +283,10 @@ int get_var(const char* name, int default_value) {
|
|||||||
} // namespace env
|
} // namespace env
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void set_finfo_limits(double& min, double& max) {
|
void set_finfo_limits(double& min, double& max, double& eps) {
|
||||||
min = numeric_limits<T>::lowest();
|
min = numeric_limits<T>::lowest();
|
||||||
max = numeric_limits<T>::max();
|
max = numeric_limits<T>::max();
|
||||||
|
eps = numeric_limits<T>::epsilon();
|
||||||
}
|
}
|
||||||
|
|
||||||
finfo::finfo(Dtype dtype) : dtype(dtype) {
|
finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||||
@ -295,16 +296,16 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
if (dtype == float32) {
|
if (dtype == float32) {
|
||||||
set_finfo_limits<float>(min, max);
|
set_finfo_limits<float>(min, max, eps);
|
||||||
} else if (dtype == float16) {
|
} else if (dtype == float16) {
|
||||||
set_finfo_limits<float16_t>(min, max);
|
set_finfo_limits<float16_t>(min, max, eps);
|
||||||
} else if (dtype == bfloat16) {
|
} else if (dtype == bfloat16) {
|
||||||
set_finfo_limits<bfloat16_t>(min, max);
|
set_finfo_limits<bfloat16_t>(min, max, eps);
|
||||||
} else if (dtype == float64) {
|
} else if (dtype == float64) {
|
||||||
set_finfo_limits<double>(min, max);
|
set_finfo_limits<double>(min, max, eps);
|
||||||
} else if (dtype == complex64) {
|
} else if (dtype == complex64) {
|
||||||
this->dtype = float32;
|
this->dtype = float32;
|
||||||
set_finfo_limits<float>(min, max);
|
set_finfo_limits<float>(min, max, eps);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,6 +65,7 @@ struct finfo {
|
|||||||
Dtype dtype;
|
Dtype dtype;
|
||||||
double min;
|
double min;
|
||||||
double max;
|
double max;
|
||||||
|
double eps;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** Holds information about integral types. */
|
/** Holds information about integral types. */
|
||||||
|
@ -197,6 +197,13 @@ void init_array(nb::module_& m) {
|
|||||||
"max",
|
"max",
|
||||||
&mx::finfo::max,
|
&mx::finfo::max,
|
||||||
R"pbdoc(The largest representable number.)pbdoc")
|
R"pbdoc(The largest representable number.)pbdoc")
|
||||||
|
.def_ro(
|
||||||
|
"eps",
|
||||||
|
&mx::finfo::eps,
|
||||||
|
R"pbdoc(
|
||||||
|
The difference between 1.0 and the next smallest
|
||||||
|
representable number larger than 1.0.
|
||||||
|
)pbdoc")
|
||||||
.def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
|
.def_ro("dtype", &mx::finfo::dtype, R"pbdoc(The :obj:`Dtype`.)pbdoc")
|
||||||
.def("__repr__", [](const mx::finfo& f) {
|
.def("__repr__", [](const mx::finfo& f) {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
|
@ -103,10 +103,12 @@ class TestDtypes(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min)
|
self.assertEqual(mx.finfo(mx.float32).min, np.finfo(np.float32).min)
|
||||||
self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max)
|
self.assertEqual(mx.finfo(mx.float32).max, np.finfo(np.float32).max)
|
||||||
|
self.assertEqual(mx.finfo(mx.float32).eps, np.finfo(np.float32).eps)
|
||||||
self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32)
|
self.assertEqual(mx.finfo(mx.float32).dtype, mx.float32)
|
||||||
|
|
||||||
self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min)
|
self.assertEqual(mx.finfo(mx.float16).min, np.finfo(np.float16).min)
|
||||||
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
|
self.assertEqual(mx.finfo(mx.float16).max, np.finfo(np.float16).max)
|
||||||
|
self.assertEqual(mx.finfo(mx.float16).eps, np.finfo(np.float16).eps)
|
||||||
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
|
self.assertEqual(mx.finfo(mx.float16).dtype, mx.float16)
|
||||||
|
|
||||||
def test_iinfo(self):
|
def test_iinfo(self):
|
||||||
|
@ -232,6 +232,11 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
for M, M_plus in zip(AB, pinvs):
|
for M, M_plus in zip(AB, pinvs):
|
||||||
self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3))
|
self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3))
|
||||||
|
|
||||||
|
# Test singular matrix
|
||||||
|
A = mx.array([[4.0, 1.0], [4.0, 1.0]])
|
||||||
|
A_plus = mx.linalg.pinv(A, stream=mx.cpu)
|
||||||
|
self.assertTrue(mx.allclose(A @ A_plus @ A, A))
|
||||||
|
|
||||||
def test_cholesky_inv(self):
|
def test_cholesky_inv(self):
|
||||||
mx.random.seed(7)
|
mx.random.seed(7)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user