mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Add matrix inversion primitive (#822)
This commit is contained in:
@@ -300,3 +300,25 @@ TEST_CASE("test SVD factorization") {
|
||||
CHECK_EQ(S.dtype(), float32);
|
||||
CHECK_EQ(Vt.dtype(), float32);
|
||||
}
|
||||
|
||||
TEST_CASE("test matrix inversion") {
|
||||
// 0D and 1D throw
|
||||
CHECK_THROWS(linalg::inv(array(0.0), Device::cpu));
|
||||
CHECK_THROWS(linalg::inv(array({0.0, 1.0}), Device::cpu));
|
||||
|
||||
// Unsupported types throw
|
||||
CHECK_THROWS(linalg::inv(array({0, 1}, {1, 2}), Device::cpu));
|
||||
|
||||
// Non-square throws.
|
||||
CHECK_THROWS(linalg::inv(array({1, 2, 3, 4, 5, 6}, {2, 3}), Device::cpu));
|
||||
|
||||
const auto prng_key = random::key(42);
|
||||
const auto A = random::normal({5, 5}, prng_key);
|
||||
const auto A_inv = linalg::inv(A, Device::cpu);
|
||||
const auto identity = eye(A.shape(0));
|
||||
|
||||
CHECK(allclose(matmul(A, A_inv), identity, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
CHECK(allclose(matmul(A_inv, A), identity, /* rtol = */ 0, /* atol = */ 1e-6)
|
||||
.item<bool>());
|
||||
}
|
||||
|
Reference in New Issue
Block a user