mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Implemented Cholesky on CPU (#1119)
This commit is contained in:
		| @@ -260,4 +260,33 @@ void init_linalg(nb::module_& parent_module) { | ||||
|         Returns: | ||||
|             array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])`` | ||||
|       )pbdoc"); | ||||
|   m.def( | ||||
|       "cholesky", | ||||
|       &cholesky, | ||||
|       "a"_a, | ||||
|       "upper"_a = false, | ||||
|       nb::kw_only(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def cholesky(a: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|         Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix. | ||||
|  | ||||
|         This function supports arrays with at least 2 dimensions. When the input | ||||
|         has more than two dimensions, the Cholesky decomposition is computed for each matrix | ||||
|         in the last two dimensions of ``a``. | ||||
|  | ||||
|         If the input matrix is not symmetric positive semi-definite, behaviour is undefined. | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|             upper (bool, optional): If ``True``, return the upper triangular Cholesky factor. | ||||
|               If ``False``, return the lower triangular Cholesky factor. Default: ``False``. | ||||
|             stream (Stream, optional): Stream or device. Defaults to ``None`` | ||||
|               in which case the default stream of the default device is used. | ||||
|  | ||||
|         Returns: | ||||
|             array: if ``upper = False``, it returns a lower trinagular ``L``matrix such that ``dot(L, L.T) = a``. | ||||
|               If ``upper = True``, it returns an upper triangular ``U`` matrix such that ``dot(U.T, U) = a``. | ||||
|       )pbdoc"); | ||||
| } | ||||
|   | ||||
| @@ -150,6 +150,23 @@ class TestLinalg(mlx_tests.MLXTestCase): | ||||
|                 mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5) | ||||
|             ) | ||||
|  | ||||
|     def test_cholesky(self): | ||||
|         sqrtA = mx.array( | ||||
|             [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float32 | ||||
|         ) | ||||
|         A = sqrtA.T @ sqrtA / 81 | ||||
|         L = mx.linalg.cholesky(A, stream=mx.cpu) | ||||
|         U = mx.linalg.cholesky(A, upper=True, stream=mx.cpu) | ||||
|         self.assertTrue(mx.allclose(L @ L.T, A, rtol=1e-5, atol=1e-7)) | ||||
|         self.assertTrue(mx.allclose(U.T @ U, A, rtol=1e-5, atol=1e-7)) | ||||
|  | ||||
|         # Multiple matrices | ||||
|         B = A + 1 / 9 | ||||
|         AB = mx.stack([A, B]) | ||||
|         Ls = mx.linalg.cholesky(AB, stream=mx.cpu) | ||||
|         for M, L in zip(AB, Ls): | ||||
|             self.assertTrue(mx.allclose(L @ L.T, M, rtol=1e-5, atol=1e-7)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Luca Arnaboldi
					Luca Arnaboldi