mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-08 10:14:43 +08:00
Implemented Cholesky on CPU (#1119)
This commit is contained in:
@@ -322,3 +322,29 @@ TEST_CASE("test matrix inversion") {
|
||||
CHECK(allclose(matmul(A_inv, A), identity, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test matrix cholesky") {
|
||||
// 0D and 1D throw
|
||||
CHECK_THROWS(linalg::cholesky(array(0.0), /* upper = */ false, Device::cpu));
|
||||
CHECK_THROWS(
|
||||
linalg::cholesky(array({0.0, 1.0}), /* upper = */ false, Device::cpu));
|
||||
|
||||
// Unsupported types throw
|
||||
CHECK_THROWS(linalg::cholesky(
|
||||
array({0, 1}, {1, 2}), /* upper = */ false, Device::cpu));
|
||||
|
||||
// Non-square throws.
|
||||
CHECK_THROWS(linalg::cholesky(
|
||||
array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ false, Device::cpu));
|
||||
|
||||
const auto prng_key = random::key(220398);
|
||||
const auto sqrtA = random::normal({5, 5}, prng_key);
|
||||
const auto A = matmul(sqrtA, transpose(sqrtA));
|
||||
const auto L = linalg::cholesky(A, /* upper = */ false, Device::cpu);
|
||||
const auto U = linalg::cholesky(A, /* upper = */ true, Device::cpu);
|
||||
|
||||
CHECK(allclose(matmul(L, transpose(L)), A, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
CHECK(allclose(matmul(transpose(U), U), A, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
}
|
Reference in New Issue
Block a user