From b74b2902016204117040949231887f0622bc2c39 Mon Sep 17 00:00:00 2001 From: aleinin <95333017+abeleinin@users.noreply.github.com> Date: Mon, 17 Feb 2025 20:44:21 -0600 Subject: [PATCH] fix tri_inv bug --- mlx/linalg.cpp | 2 +- python/tests/test_linalg.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index e9a0d6e5a..e9ce9399b 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -284,7 +284,7 @@ array tri_inv( const array& a, bool upper /* = false */, StreamOrDevice s /* = {} */) { - return inv_impl(a, /*tri=*/true, upper, s); + return inv_impl(upper ? triu(a) : tril(a), /*tri=*/true, upper, s); } array cholesky( diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 67e8d7bf9..e4ddc1b43 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -170,10 +170,19 @@ class TestLinalg(mlx_tests.MLXTestCase): B = B.T AB = mx.stack([A, B]) invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu) - for M, M_inv in zip(AB, invs): + diag_invs = mx.linalg.tri_inv(AB, upper=(not upper), stream=mx.cpu) + for M, M_inv, M_diag_invs in zip(AB, invs, diag_invs): self.assertTrue( mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5) ) + self.assertTrue( + mx.allclose( + (mx.tril(M) if upper else mx.triu(M)) @ M_diag_invs, + mx.eye(M.shape[0]), + rtol=0, + atol=1e-5, + ) + ) def test_cholesky(self): sqrtA = mx.array(