This commit is contained in:
Aditya Dhulipala
2024-08-27 23:06:12 -07:00
committed by GitHub
parent e64349bbdd
commit e6b223df5f
5 changed files with 124 additions and 1 deletions

View File

@@ -347,4 +347,46 @@ TEST_CASE("test matrix cholesky") {
.item<bool>());
CHECK(allclose(matmul(transpose(U), U), A, /* rtol = */ 0, /* atol = */ 1e-6)
.item<bool>());
}
}
TEST_CASE("test matrix pseudo-inverse") {
// 0D and 1D throw
CHECK_THROWS(linalg::pinv(array(0.0), Device::cpu));
CHECK_THROWS(linalg::pinv(array({0.0, 1.0}), Device::cpu));
// Unsupported types throw
CHECK_THROWS(linalg::pinv(array({0, 1}, {1, 2}), Device::cpu));
{ // Square m == n
const auto A = array({1.0, 2.0, 3.0, 4.0}, {2, 2});
const auto A_pinv = linalg::pinv(A, Device::cpu);
const auto A_again = matmul(matmul(A, A_pinv), A);
CHECK(allclose(A_again, A).item<bool>());
const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
}
{ // Rectangular matrix m < n
const auto prng_key = random::key(42);
const auto A = random::normal({4, 5}, prng_key);
const auto A_pinv = linalg::pinv(A, Device::cpu);
const auto zeros = zeros_like(A_pinv, Device::cpu);
CHECK_FALSE(allclose(zeros, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6)
.item<bool>());
const auto A_again = matmul(matmul(A, A_pinv), A);
CHECK(allclose(A_again, A).item<bool>());
const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
}
{ // Rectangular matrix m > n
const auto prng_key = random::key(10);
const auto A = random::normal({6, 5}, prng_key);
const auto A_pinv = linalg::pinv(A, Device::cpu);
const auto zeros2 = zeros_like(A_pinv, Device::cpu);
CHECK_FALSE(allclose(zeros2, A_pinv, /* rtol = */ 0, /* atol = */ 1e-6)
.item<bool>());
const auto A_again = matmul(matmul(A, A_pinv), A);
CHECK(allclose(A_again, A).item<bool>());
const auto A_pinv_again = matmul(matmul(A_pinv, A), A_pinv);
CHECK(allclose(A_pinv_again, A_pinv).item<bool>());
}
}