From 344a29506ec8c42c8536a7acb64d4c8a52776758 Mon Sep 17 00:00:00 2001 From: Abe Leininger <95333017+abeleinin@users.noreply.github.com> Date: Wed, 19 Feb 2025 14:42:33 -0600 Subject: [PATCH] Enforce triangular matrix form in `tri_inv` (#1876) * fix tri_inv bug * Revert "fix tri_inv bug" This reverts commit b74b2902016204117040949231887f0622bc2c39. * Make sure that tri_inv returns a triangular matrix --------- Co-authored-by: Angelos Katharopoulos --- mlx/backend/cpu/inverse.cpp | 17 ++++++++++++++++- python/tests/test_linalg.py | 7 +++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/mlx/backend/cpu/inverse.cpp b/mlx/backend/cpu/inverse.cpp index 40cd16efc..81cabb79d 100644 --- a/mlx/backend/cpu/inverse.cpp +++ b/mlx/backend/cpu/inverse.cpp @@ -81,7 +81,22 @@ void general_inv(array& inv, int N, int i) { void tri_inv(array& inv, int N, int i, bool upper) { const char uplo = upper ? 'L' : 'U'; const char diag = 'N'; - int info = strtri_wrapper(uplo, diag, inv.data() + N * N * i, N); + float* data = inv.data() + N * N * i; + int info = strtri_wrapper(uplo, diag, data, N); + + // zero out the other triangle + if (upper) { + for (int i = 0; i < N; i++) { + std::fill(data, data + i, 0.0f); + data += N; + } + } else { + for (int i = 0; i < N; i++) { + std::fill(data + i + 1, data + N, 0.0f); + data += N; + } + } + if (info != 0) { std::stringstream ss; ss << "inverse_impl: triangular inversion failed with error code " << info; diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 67e8d7bf9..bae3dc17a 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -175,6 +175,13 @@ class TestLinalg(mlx_tests.MLXTestCase): mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5) ) + # Ensure that tri_inv will 0-out the supposedly 0 triangle + x = mx.random.normal((2, 8, 8)) + y1 = mx.linalg.tri_inv(x, upper=True, stream=mx.cpu) + y2 = mx.linalg.tri_inv(x, upper=False, stream=mx.cpu) + self.assertTrue(mx.all(y1 == mx.triu(y1))) + self.assertTrue(mx.all(y2 == mx.tril(y2))) + def test_cholesky(self): sqrtA = mx.array( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float32