mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
fix pinv (#2110)
This commit is contained in:
@@ -379,7 +379,12 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
// Prepare S
|
||||
S = expand_dims(S, -2, s);
|
||||
|
||||
return matmul(divide(V, S, s), U);
|
||||
auto rcond = 10. * std::max(m, n) * finfo(a.dtype()).eps;
|
||||
auto cutoff = multiply(array(rcond, a.dtype()), max(S, -1, true, s), s);
|
||||
auto rS =
|
||||
where(greater(S, cutoff, s), reciprocal(S, s), array(0.0f, a.dtype()), s);
|
||||
|
||||
return matmul(multiply(V, rS, s), U, s);
|
||||
}
|
||||
|
||||
array cholesky_inv(
|
||||
|
||||
Reference in New Issue
Block a user