Eigenvalues and eigenvectors (#1334)

* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge merge Eighvalsh and Eigh into a single primitive

* use the same primate with the flag

* fix primatives

* use MULTI

* fix eval_gpu

* fix decleration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Kashif Rasul
2024-10-22 21:18:48 +02:00
committed by GitHub
parent c26208f67d
commit 3ddc07e936
23 changed files with 434 additions and 86 deletions

View File

@@ -435,3 +435,41 @@ TEST_CASE("test cross product") {
result = cross(a, b);
CHECK(allclose(result, expected).item<bool>());
}
TEST_CASE("test matrix eigh") {
// 0D and 1D throw
CHECK_THROWS(linalg::eigh(array(0.0)));
CHECK_THROWS(linalg::eigh(array({0.0, 1.0})));
CHECK_THROWS(linalg::eigvalsh(array(0.0)));
CHECK_THROWS(linalg::eigvalsh(array({0.0, 1.0})));
// Unsupported types throw
CHECK_THROWS(linalg::eigh(array({0, 1}, {1, 2})));
// Non-square throws
CHECK_THROWS(linalg::eigh(array({1, 2, 3, 4, 5, 6}, {2, 3})));
// Test a simple 2x2 symmetric matrix
array A = array({1.0, 2.0, 2.0, 4.0}, {2, 2}, float32);
auto [eigvals, eigvecs] = linalg::eigh(A, "L", Device::cpu);
// Expected eigenvalues
array expected_eigvals = array({0.0, 5.0});
CHECK(allclose(
eigvals,
expected_eigvals,
/* rtol = */ 1e-5,
/* atol = */ 1e-5)
.item<bool>());
// Verify orthogonality of eigenvectors
CHECK(allclose(
matmul(eigvecs, transpose(eigvecs)),
eye(2),
/* rtol = */ 1e-5,
/* atol = */ 1e-5)
.item<bool>());
// Verify eigendecomposition
CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
}