mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Implemented Cholesky on CPU (#1119)
This commit is contained in:
		| @@ -322,3 +322,29 @@ TEST_CASE("test matrix inversion") { | ||||
|   CHECK(allclose(matmul(A_inv, A), identity, /* rtol = */ 0, /* atol = */ 1e-6) | ||||
|             .item<bool>()); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test matrix cholesky") { | ||||
|   // 0D and 1D throw | ||||
|   CHECK_THROWS(linalg::cholesky(array(0.0), /* upper = */ false, Device::cpu)); | ||||
|   CHECK_THROWS( | ||||
|       linalg::cholesky(array({0.0, 1.0}), /* upper = */ false, Device::cpu)); | ||||
|  | ||||
|   // Unsupported types throw | ||||
|   CHECK_THROWS(linalg::cholesky( | ||||
|       array({0, 1}, {1, 2}), /* upper = */ false, Device::cpu)); | ||||
|  | ||||
|   // Non-square throws. | ||||
|   CHECK_THROWS(linalg::cholesky( | ||||
|       array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ false, Device::cpu)); | ||||
|  | ||||
|   const auto prng_key = random::key(220398); | ||||
|   const auto sqrtA = random::normal({5, 5}, prng_key); | ||||
|   const auto A = matmul(sqrtA, transpose(sqrtA)); | ||||
|   const auto L = linalg::cholesky(A, /* upper = */ false, Device::cpu); | ||||
|   const auto U = linalg::cholesky(A, /* upper = */ true, Device::cpu); | ||||
|  | ||||
|   CHECK(allclose(matmul(L, transpose(L)), A, /* rtol = */ 0, /* atol = */ 1e-6) | ||||
|             .item<bool>()); | ||||
|   CHECK(allclose(matmul(transpose(U), U), A, /* rtol = */ 0, /* atol = */ 1e-6) | ||||
|             .item<bool>()); | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Luca Arnaboldi
					Luca Arnaboldi