diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index e9a0d6e5a..e9ce9399b 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -284,7 +284,7 @@ array tri_inv( const array& a, bool upper /* = false */, StreamOrDevice s /* = {} */) { - return inv_impl(a, /*tri=*/true, upper, s); + return inv_impl(upper ? triu(a) : tril(a), /*tri=*/true, upper, s); } array cholesky( diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 67e8d7bf9..e4ddc1b43 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -170,10 +170,19 @@ class TestLinalg(mlx_tests.MLXTestCase): B = B.T AB = mx.stack([A, B]) invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu) - for M, M_inv in zip(AB, invs): + diag_invs = mx.linalg.tri_inv(AB, upper=(not upper), stream=mx.cpu) + for M, M_inv, M_diag_invs in zip(AB, invs, diag_invs): self.assertTrue( mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5) ) + self.assertTrue( + mx.allclose( + (mx.tril(M) if upper else mx.triu(M)) @ M_diag_invs, + mx.eye(M.shape[0]), + rtol=0, + atol=1e-5, + ) + ) def test_cholesky(self): sqrtA = mx.array(