mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	QR factorization (#310)
* add qr factorization --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -14,7 +14,7 @@ if (MLX_BUILD_METAL) | ||||
|   ) | ||||
| endif() | ||||
|  | ||||
| target_sources(tests PRIVATE  | ||||
| target_sources(tests PRIVATE | ||||
|   allocator_tests.cpp | ||||
|   array_tests.cpp | ||||
|   arg_reduce_tests.cpp | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include "doctest/doctest.h" | ||||
|  | ||||
| @@ -248,3 +248,22 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") { | ||||
|             array({14.28285686, 39.7617907})) | ||||
|             .item<bool>()); | ||||
| } | ||||
|  | ||||
| TEST_CASE("test QR factorization") { | ||||
|   // 0D and 1D throw | ||||
|   CHECK_THROWS(linalg::qr(array(0.0))); | ||||
|   CHECK_THROWS(linalg::qr(array({0.0, 1.0}))); | ||||
|  | ||||
|   // Unsupported types throw | ||||
|   CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2}))); | ||||
|  | ||||
|   array A = array({{2., 3., 1., 2.}, {2, 2}}); | ||||
|   auto [Q, R] = linalg::qr(A, Device::cpu); | ||||
|   auto out = matmul(Q, R); | ||||
|   CHECK(allclose(out, A).item<bool>()); | ||||
|   out = matmul(Q, Q); | ||||
|   CHECK(allclose(out, eye(2), 1e-5, 1e-7).item<bool>()); | ||||
|   CHECK(allclose(tril(R, -1), zeros_like(R)).item<bool>()); | ||||
|   CHECK_EQ(Q.dtype(), float32); | ||||
|   CHECK_EQ(R.dtype(), float32); | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 taher
					taher