mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 14:58:11 +08:00
QR factorization (#310)
* add qr factorization --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
@@ -248,3 +248,22 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||
array({14.28285686, 39.7617907}))
|
||||
.item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test QR factorization") {
|
||||
// 0D and 1D throw
|
||||
CHECK_THROWS(linalg::qr(array(0.0)));
|
||||
CHECK_THROWS(linalg::qr(array({0.0, 1.0})));
|
||||
|
||||
// Unsupported types throw
|
||||
CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2})));
|
||||
|
||||
array A = array({{2., 3., 1., 2.}, {2, 2}});
|
||||
auto [Q, R] = linalg::qr(A, Device::cpu);
|
||||
auto out = matmul(Q, R);
|
||||
CHECK(allclose(out, A).item<bool>());
|
||||
out = matmul(Q, Q);
|
||||
CHECK(allclose(out, eye(2), 1e-5, 1e-7).item<bool>());
|
||||
CHECK(allclose(tril(R, -1), zeros_like(R)).item<bool>());
|
||||
CHECK_EQ(Q.dtype(), float32);
|
||||
CHECK_EQ(R.dtype(), float32);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user