mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 21:04:41 +08:00
Pinv (#875)
This commit is contained in:
@@ -347,4 +347,46 @@ TEST_CASE("test matrix cholesky") {
|
||||
.item<bool>());
|
||||
CHECK(allclose(matmul(transpose(U), U), A, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("test matrix pseudo-inverse") {
|
||||
// 0D and 1D throw
|
||||
CHECK_THROWS(linalg::pinv(array(0.0), Device::cpu));
|
||||
CHECK_THROWS(linalg::pinv(array({0.0, 1.0}), Device::cpu));
|
||||
|
||||
// Unsupported types throw
|
||||
CHECK_THROWS(linalg::pinv(array({0, 1}, {1, 2}), Device::cpu));
|
||||
|
||||
{ // Square m == n
|
||||
const auto A = array({1.0, 2.0, 3.0, 4.0}, {2, 2});
|
||||
const auto A_pinv = linalg::pinv(A, Device::cpu);
|
||||
const auto A_again = matmul(matmul(A, A_pinv), A);
|
||||
CHECK(allclose(A_again, A).item<bool>());
|
||||
const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);
|
||||
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
|
||||
}
|
||||
{ // Rectangular matrix m < n
|
||||
const auto prng_key = random::key(42);
|
||||
const auto A = random::normal({4, 5}, prng_key);
|
||||
const auto A_pinv = linalg::pinv(A, Device::cpu);
|
||||
const auto zeros = zeros_like(A_pinv, Device::cpu);
|
||||
CHECK_FALSE(allclose(zeros, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
const auto A_again = matmul(matmul(A, A_pinv), A);
|
||||
CHECK(allclose(A_again, A).item<bool>());
|
||||
const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);
|
||||
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
|
||||
}
|
||||
{ // Rectangular matrix m > n
|
||||
const auto prng_key = random::key(10);
|
||||
const auto A = random::normal({6, 5}, prng_key);
|
||||
const auto A_pinv = linalg::pinv(A, Device::cpu);
|
||||
const auto zeros2 = zeros_like(A_pinv, Device::cpu);
|
||||
CHECK_FALSE(allclose(zeros2, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
const auto A_again = matmul(matmul(A, A_pinv), A);
|
||||
CHECK(allclose(A_again, A).item<bool>());
|
||||
const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);
|
||||
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user