mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 06:44:40 +08:00
Pinv (#875)
This commit is contained in:
@@ -181,6 +181,18 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
for M, L in zip(AB, Ls):
|
||||
self.assertTrue(mx.allclose(L @ L.T, M, rtol=1e-5, atol=1e-7))
|
||||
|
||||
def test_pseudo_inverse(self):
|
||||
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
|
||||
A_plus = mx.linalg.pinv(A, stream=mx.cpu)
|
||||
self.assertTrue(mx.allclose(A @ A_plus @ A, A, rtol=0, atol=1e-5))
|
||||
|
||||
# Multiple matrices
|
||||
B = A - 100
|
||||
AB = mx.stack([A, B])
|
||||
pinvs = mx.linalg.pinv(AB, stream=mx.cpu)
|
||||
for M, M_plus in zip(AB, pinvs):
|
||||
self.assertTrue(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3))
|
||||
|
||||
def test_cholesky_inv(self):
|
||||
mx.random.seed(7)
|
||||
|
||||
|
Reference in New Issue
Block a user