mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Double for lapack (#1904)
* double for lapack ops * add double support for lapack ops
This commit is contained in:
		| @@ -183,6 +183,115 @@ class TestDouble(mlx_tests.MLXTestCase): | ||||
|             c = a + b | ||||
|             self.assertEqual(c.dtype, mx.float64) | ||||
|  | ||||
|     def test_lapack(self): | ||||
|         with mx.stream(mx.cpu): | ||||
|             # QRF | ||||
|             A = mx.array([[2.0, 3.0], [1.0, 2.0]], dtype=mx.float64) | ||||
|             Q, R = mx.linalg.qr(A) | ||||
|             out = Q @ R | ||||
|             self.assertTrue(mx.allclose(out, A)) | ||||
|             out = Q.T @ Q | ||||
|             self.assertTrue(mx.allclose(out, mx.eye(2))) | ||||
|             self.assertTrue(mx.allclose(mx.tril(R, -1), mx.zeros_like(R))) | ||||
|             self.assertEqual(Q.dtype, mx.float64) | ||||
|             self.assertEqual(R.dtype, mx.float64) | ||||
|  | ||||
|             # SVD | ||||
|             A = mx.array( | ||||
|                 [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float64 | ||||
|             ) | ||||
|             U, S, Vt = mx.linalg.svd(A) | ||||
|             self.assertTrue(mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A)) | ||||
|  | ||||
|             # Inverse | ||||
|             A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64) | ||||
|             A_inv = mx.linalg.inv(A) | ||||
|             self.assertTrue(mx.allclose(A @ A_inv, mx.eye(A.shape[0]))) | ||||
|  | ||||
|             # Tri inv | ||||
|             A = mx.array([[1, 0, 0], [6, -5, 0], [-9, 8, 7]], dtype=mx.float64) | ||||
|             B = mx.array([[7, 0, 0], [3, -2, 0], [1, 8, 3]], dtype=mx.float64) | ||||
|             AB = mx.stack([A, B]) | ||||
|             invs = mx.linalg.tri_inv(AB, upper=False) | ||||
|             for M, M_inv in zip(AB, invs): | ||||
|                 self.assertTrue(mx.allclose(M @ M_inv, mx.eye(M.shape[0]))) | ||||
|  | ||||
|             # Cholesky | ||||
|             sqrtA = mx.array( | ||||
|                 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float64 | ||||
|             ) | ||||
|             A = sqrtA.T @ sqrtA / 81 | ||||
|             L = mx.linalg.cholesky(A) | ||||
|             U = mx.linalg.cholesky(A, upper=True) | ||||
|             self.assertTrue(mx.allclose(L @ L.T, A)) | ||||
|             self.assertTrue(mx.allclose(U.T @ U, A)) | ||||
|  | ||||
|             # Psueod inverse | ||||
|             A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float64) | ||||
|             A_plus = mx.linalg.pinv(A) | ||||
|             self.assertTrue(mx.allclose(A @ A_plus @ A, A)) | ||||
|  | ||||
|             # Eigh | ||||
|             def check_eigs_and_vecs(A_np, kwargs={}): | ||||
|                 A = mx.array(A_np, dtype=mx.float64) | ||||
|                 eig_vals, eig_vecs = mx.linalg.eigh(A, **kwargs) | ||||
|                 eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs) | ||||
|                 self.assertTrue(np.allclose(eig_vals, eig_vals_np)) | ||||
|                 self.assertTrue( | ||||
|                     mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs) | ||||
|                 ) | ||||
|  | ||||
|                 eig_vals_only = mx.linalg.eigvalsh(A, **kwargs) | ||||
|                 self.assertTrue(mx.allclose(eig_vals, eig_vals_only)) | ||||
|  | ||||
|             # Test a simple 2x2 symmetric matrix | ||||
|             A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float64) | ||||
|             check_eigs_and_vecs(A_np) | ||||
|  | ||||
|             # Test a larger random symmetric matrix | ||||
|             n = 5 | ||||
|             np.random.seed(1) | ||||
|             A_np = np.random.randn(n, n).astype(np.float64) | ||||
|             A_np = (A_np + A_np.T) / 2 | ||||
|             check_eigs_and_vecs(A_np) | ||||
|  | ||||
|             # Test with upper triangle | ||||
|             check_eigs_and_vecs(A_np, {"UPLO": "U"}) | ||||
|  | ||||
|             # LU factorization | ||||
|             # Test 3x3 matrix | ||||
|             a = mx.array( | ||||
|                 [[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]], dtype=mx.float64 | ||||
|             ) | ||||
|             P, L, U = mx.linalg.lu(a) | ||||
|             self.assertTrue(mx.allclose(L[P, :] @ U, a)) | ||||
|  | ||||
|             # Solve triangular | ||||
|             # Test lower triangular matrix | ||||
|             a = mx.array( | ||||
|                 [[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]], dtype=mx.float64 | ||||
|             ) | ||||
|             b = mx.array([8.0, 14.0, 3.0], dtype=mx.float64) | ||||
|  | ||||
|             result = mx.linalg.solve_triangular(a, b, upper=False) | ||||
|             expected = np.linalg.solve(np.array(a), np.array(b)) | ||||
|             self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|             # Test upper triangular matrix | ||||
|             a = mx.array( | ||||
|                 [[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]], dtype=mx.float64 | ||||
|             ) | ||||
|             b = mx.array([13.0, 33.0, 18.0], dtype=mx.float64) | ||||
|  | ||||
|             result = mx.linalg.solve_triangular(a, b, upper=True) | ||||
|             expected = np.linalg.solve(np.array(a), np.array(b)) | ||||
|             self.assertTrue(np.allclose(result, expected)) | ||||
|  | ||||
|     def test_conversion(self): | ||||
|         a = mx.array([1.0, 2.0], mx.float64) | ||||
|         b = np.array(a) | ||||
|         self.assertTrue(np.array_equal(a, b)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun