Add matrix inversion primitive (#822)

This commit is contained in:
nicolov
2024-03-15 14:34:36 +01:00
committed by GitHub
parent 19ec023256
commit eaba55c9bf
13 changed files with 204 additions and 4 deletions

View File

@@ -300,3 +300,25 @@ TEST_CASE("test SVD factorization") {
CHECK_EQ(S.dtype(), float32);
CHECK_EQ(Vt.dtype(), float32);
}
TEST_CASE("test matrix inversion") {
// 0D and 1D throw
CHECK_THROWS(linalg::inv(array(0.0), Device::cpu));
CHECK_THROWS(linalg::inv(array({0.0, 1.0}), Device::cpu));
// Unsupported types throw
CHECK_THROWS(linalg::inv(array({0, 1}, {1, 2}), Device::cpu));
// Non-square throws.
CHECK_THROWS(linalg::inv(array({1, 2, 3, 4, 5, 6}, {2, 3}), Device::cpu));
const auto prng_key = random::key(42);
const auto A = random::normal({5, 5}, prng_key);
const auto A_inv = linalg::inv(A, Device::cpu);
const auto identity = eye(A.shape(0));
CHECK(allclose(matmul(A, A_inv), identity, /* rtol = */ 0, /* atol = */ 1e-6)
.item<bool>());
CHECK(allclose(matmul(A_inv, A), identity, /* rtol = */ 0, /* atol = */ 1e-6)
.item<bool>());
}