mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix tri_inv bug
This commit is contained in:
parent
4c1dfa58b7
commit
b74b290201
@ -284,7 +284,7 @@ array tri_inv(
|
||||
const array& a,
|
||||
bool upper /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return inv_impl(a, /*tri=*/true, upper, s);
|
||||
return inv_impl(upper ? triu(a) : tril(a), /*tri=*/true, upper, s);
|
||||
}
|
||||
|
||||
array cholesky(
|
||||
|
@ -170,10 +170,19 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
B = B.T
|
||||
AB = mx.stack([A, B])
|
||||
invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu)
|
||||
for M, M_inv in zip(AB, invs):
|
||||
diag_invs = mx.linalg.tri_inv(AB, upper=(not upper), stream=mx.cpu)
|
||||
for M, M_inv, M_diag_invs in zip(AB, invs, diag_invs):
|
||||
self.assertTrue(
|
||||
mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
(mx.tril(M) if upper else mx.triu(M)) @ M_diag_invs,
|
||||
mx.eye(M.shape[0]),
|
||||
rtol=0,
|
||||
atol=1e-5,
|
||||
)
|
||||
)
|
||||
|
||||
def test_cholesky(self):
|
||||
sqrtA = mx.array(
|
||||
|
Loading…
Reference in New Issue
Block a user