mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix pinv (#2110)
This commit is contained in:
@@ -379,7 +379,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
// Prepare S
|
||||
S = expand_dims(S, -2, s);
|
||||
|
||||
return matmul(divide(V, S, s), U);
|
||||
auto rcond = 10. * std::max(m, n) * finfo(a.dtype()).eps;
|
||||
auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s);
|
||||
auto rS =
|
||||
where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s);
|
||||
|
||||
return matmul(multiply(V, rS, s), U, s);
|
||||
}
|
||||
|
||||
array cholesky_inv(
|
||||
|
||||
@@ -33,6 +33,9 @@ struct numeric_limits<float16_t> {
|
||||
static constexpr float16_t max() {
|
||||
return bits_to_half(0x7BFF);
|
||||
}
|
||||
static constexpr float16_t epsilon() {
|
||||
return bits_to_half(0x1400);
|
||||
}
|
||||
static constexpr float16_t infinity() {
|
||||
return bits_to_half(0x7C00);
|
||||
}
|
||||
@@ -56,6 +59,9 @@ struct numeric_limits<bfloat16_t> {
|
||||
static constexpr bfloat16_t max() {
|
||||
return bits_to_bfloat(0x7F7F);
|
||||
}
|
||||
static constexpr bfloat16_t epsilon() {
|
||||
return bits_to_bfloat(0x3C00);
|
||||
}
|
||||
static constexpr bfloat16_t infinity() {
|
||||
return bits_to_bfloat(0x7F80);
|
||||
}
|
||||
|
||||
@@ -283,9 +283,10 @@ int get_var(const char* name, int default_value) {
|
||||
} // namespace env
|
||||
|
||||
template <typename T>
|
||||
void set_finfo_limits(double& min, double& max) {
|
||||
void set_finfo_limits(double& min, double& max, double& eps) {
|
||||
min = numeric_limits<T>::lowest();
|
||||
max = numeric_limits<T>::max();
|
||||
eps = numeric_limits<T>::epsilon();
|
||||
}
|
||||
|
||||
finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||
@@ -295,16 +296,16 @@ finfo::finfo(Dtype dtype) : dtype(dtype) {
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (dtype == float32) {
|
||||
set_finfo_limits<float>(min, max);
|
||||
set_finfo_limits<float>(min, max, eps);
|
||||
} else if (dtype == float16) {
|
||||
set_finfo_limits<float16_t>(min, max);
|
||||
set_finfo_limits<float16_t>(min, max, eps);
|
||||
} else if (dtype == bfloat16) {
|
||||
set_finfo_limits<bfloat16_t>(min, max);
|
||||
set_finfo_limits<bfloat16_t>(min, max, eps);
|
||||
} else if (dtype == float64) {
|
||||
set_finfo_limits<double>(min, max);
|
||||
set_finfo_limits<double>(min, max, eps);
|
||||
} else if (dtype == complex64) {
|
||||
this->dtype = float32;
|
||||
set_finfo_limits<float>(min, max);
|
||||
set_finfo_limits<float>(min, max, eps);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -65,6 +65,7 @@ struct finfo {
|
||||
Dtype dtype;
|
||||
double min;
|
||||
double max;
|
||||
double eps;
|
||||
};
|
||||
|
||||
/** Holds information about integral types. */
|
||||
|
||||
Reference in New Issue
Block a user