fix tri_inv bug

This commit is contained in:
aleinin 2025-02-17 20:44:21 -06:00
parent 4c1dfa58b7
commit b74b290201
2 changed files with 11 additions and 2 deletions

View File

@ -284,7 +284,7 @@ array tri_inv(
const array& a, const array& a,
bool upper /* = false */, bool upper /* = false */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
return inv_impl(a, /*tri=*/true, upper, s); return inv_impl(upper ? triu(a) : tril(a), /*tri=*/true, upper, s);
} }
array cholesky( array cholesky(

View File

@ -170,10 +170,19 @@ class TestLinalg(mlx_tests.MLXTestCase):
B = B.T B = B.T
AB = mx.stack([A, B]) AB = mx.stack([A, B])
invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu) invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu)
for M, M_inv in zip(AB, invs): diag_invs = mx.linalg.tri_inv(AB, upper=(not upper), stream=mx.cpu)
for M, M_inv, M_diag_invs in zip(AB, invs, diag_invs):
self.assertTrue( self.assertTrue(
mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5) mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)
) )
self.assertTrue(
mx.allclose(
(mx.tril(M) if upper else mx.triu(M)) @ M_diag_invs,
mx.eye(M.shape[0]),
rtol=0,
atol=1e-5,
)
)
def test_cholesky(self): def test_cholesky(self):
sqrtA = mx.array( sqrtA = mx.array(