QR factorization (#310)

* add qr factorization

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
taher
2024-01-26 09:27:31 -08:00
committed by GitHub
parent 2463496471
commit 077c1ee64a
20 changed files with 322 additions and 19 deletions

View File

@@ -1,4 +1,4 @@
// Copyright © 2023 Apple Inc.
// Copyright © 2023-2024 Apple Inc.
#include "doctest/doctest.h"
@@ -248,3 +248,22 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
array({14.28285686, 39.7617907}))
.item<bool>());
}
TEST_CASE("test QR factorization") {
// 0D and 1D throw
CHECK_THROWS(linalg::qr(array(0.0)));
CHECK_THROWS(linalg::qr(array({0.0, 1.0})));
// Unsupported types throw
CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2})));
array A = array({{2., 3., 1., 2.}, {2, 2}});
auto [Q, R] = linalg::qr(A, Device::cpu);
auto out = matmul(Q, R);
CHECK(allclose(out, A).item<bool>());
out = matmul(Q, Q);
CHECK(allclose(out, eye(2), 1e-5, 1e-7).item<bool>());
CHECK(allclose(tril(R, -1), zeros_like(R)).item<bool>());
CHECK_EQ(Q.dtype(), float32);
CHECK_EQ(R.dtype(), float32);
}