mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 06:24:35 +08:00
Double for lapack (#1904)
* double for lapack ops * add double support for lapack ops
This commit is contained in:
@@ -183,6 +183,115 @@ class TestDouble(mlx_tests.MLXTestCase):
|
||||
c = a + b
|
||||
self.assertEqual(c.dtype, mx.float64)
|
||||
|
||||
def test_lapack(self):
|
||||
with mx.stream(mx.cpu):
|
||||
# QRF
|
||||
A = mx.array([[2.0, 3.0], [1.0, 2.0]], dtype=mx.float64)
|
||||
Q, R = mx.linalg.qr(A)
|
||||
out = Q @ R
|
||||
self.assertTrue(mx.allclose(out, A))
|
||||
out = Q.T @ Q
|
||||
self.assertTrue(mx.allclose(out, mx.eye(2)))
|
||||
self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R)))
|
||||
self.assertEqual(Q.dtype, mx.float64)
|
||||
self.assertEqual(R.dtype, mx.float64)
|
||||
|
||||
# SVD
|
||||
A = mx.array(
|
||||
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64
|
||||
)
|
||||
U, S, Vt = mx.linalg.svd(A)
|
||||
self.assertTrue(mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A))
|
||||
|
||||
# Inverse
|
||||
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
|
||||
A_inv = mx.linalg.inv(A)
|
||||
self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0])))
|
||||
|
||||
# Tri inv
|
||||
A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float64)
|
||||
B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float64)
|
||||
AB = mx.stack([A, B])
|
||||
invs = mx.linalg.tri_inv(AB, upper=False)
|
||||
for M, M_inv in zip(AB, invs):
|
||||
self.assertTrue(mx.allclose(M @ M_inv, mx.eye(M.shape[0])))
|
||||
|
||||
# Cholesky
|
||||
sqrtA = mx.array(
|
||||
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float64
|
||||
)
|
||||
A = sqrtA.T @ sqrtA / 81
|
||||
L = mx.linalg.cholesky(A)
|
||||
U = mx.linalg.cholesky(A, upper=True)
|
||||
self.assertTrue(mx.allclose(L @ L.T, A))
|
||||
self.assertTrue(mx.allclose(U.T @ U, A))
|
||||
|
||||
# Psueod inverse
|
||||
A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64)
|
||||
A_plus = mx.linalg.pinv(A)
|
||||
self.assertTrue(mx.allclose(A @ A_plus @ A, A))
|
||||
|
||||
# Eigh
|
||||
def check_eigs_and_vecs(A_np, kwargs={}):
|
||||
A = mx.array(A_np, dtype=mx.float64)
|
||||
eig_vals, eig_vecs = mx.linalg.eigh(A, **kwargs)
|
||||
eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)
|
||||
self.assertTrue(np.allclose(eig_vals, eig_vals_np))
|
||||
self.assertTrue(
|
||||
mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs)
|
||||
)
|
||||
|
||||
eig_vals_only = mx.linalg.eigvalsh(A, **kwargs)
|
||||
self.assertTrue(mx.allclose(eig_vals, eig_vals_only))
|
||||
|
||||
# Test a simple 2x2 symmetric matrix
|
||||
A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float64)
|
||||
check_eigs_and_vecs(A_np)
|
||||
|
||||
# Test a larger random symmetric matrix
|
||||
n = 5
|
||||
np.random.seed(1)
|
||||
A_np = np.random.randn(n, n).astype(np.float64)
|
||||
A_np = (A_np + A_np.T) / 2
|
||||
check_eigs_and_vecs(A_np)
|
||||
|
||||
# Test with upper triangle
|
||||
check_eigs_and_vecs(A_np, {"UPLO": "U"})
|
||||
|
||||
# LU factorization
|
||||
# Test 3x3 matrix
|
||||
a = mx.array(
|
||||
[[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]], dtype=mx.float64
|
||||
)
|
||||
P, L, U = mx.linalg.lu(a)
|
||||
self.assertTrue(mx.allclose(L[P, :] @ U, a))
|
||||
|
||||
# Solve triangular
|
||||
# Test lower triangular matrix
|
||||
a = mx.array(
|
||||
[[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]], dtype=mx.float64
|
||||
)
|
||||
b = mx.array([8.0, 14.0, 3.0], dtype=mx.float64)
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=False)
|
||||
expected = np.linalg.solve(np.array(a), np.array(b))
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test upper triangular matrix
|
||||
a = mx.array(
|
||||
[[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]], dtype=mx.float64
|
||||
)
|
||||
b = mx.array([13.0, 33.0, 18.0], dtype=mx.float64)
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=True)
|
||||
expected = np.linalg.solve(np.array(a), np.array(b))
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
def test_conversion(self):
|
||||
a = mx.array([1.0, 2.0], mx.float64)
|
||||
b = np.array(a)
|
||||
self.assertTrue(np.array_equal(a, b))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user