mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
fix tri_inv bug
This commit is contained in:
parent
4c1dfa58b7
commit
b74b290201
@ -284,7 +284,7 @@ array tri_inv(
|
|||||||
const array& a,
|
const array& a,
|
||||||
bool upper /* = false */,
|
bool upper /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return inv_impl(a, /*tri=*/true, upper, s);
|
return inv_impl(upper ? triu(a) : tril(a), /*tri=*/true, upper, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array cholesky(
|
array cholesky(
|
||||||
|
@ -170,10 +170,19 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
B = B.T
|
B = B.T
|
||||||
AB = mx.stack([A, B])
|
AB = mx.stack([A, B])
|
||||||
invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu)
|
invs = mx.linalg.tri_inv(AB, upper=upper, stream=mx.cpu)
|
||||||
for M, M_inv in zip(AB, invs):
|
diag_invs = mx.linalg.tri_inv(AB, upper=(not upper), stream=mx.cpu)
|
||||||
|
for M, M_inv, M_diag_invs in zip(AB, invs, diag_invs):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)
|
mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)
|
||||||
)
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
(mx.tril(M) if upper else mx.triu(M)) @ M_diag_invs,
|
||||||
|
mx.eye(M.shape[0]),
|
||||||
|
rtol=0,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_cholesky(self):
|
def test_cholesky(self):
|
||||||
sqrtA = mx.array(
|
sqrtA = mx.array(
|
||||||
|
Loading…
Reference in New Issue
Block a user