CPU LU factorization and linear solvers (#1451)

* linalg solve backend

* nits

* more nits + fix

* luf primitive and lu, solve, and solve_triangular backends

* changes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Abe Leininger
2025-02-10 14:32:24 -06:00
committed by GitHub
parent 7df3f792a2
commit a5ededf1c3
12 changed files with 571 additions and 15 deletions

View File

@@ -74,6 +74,18 @@ array pinv(const array& a, StreamOrDevice s = {});
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
std::vector<array> lu(const array& a, StreamOrDevice s = {});
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});
array solve(const array& a, const array& b, StreamOrDevice s = {});
array solve_triangular(
const array& a,
const array& b,
bool upper = false,
StreamOrDevice s = {});
/**
* Compute the cross product of two arrays along the given axis.
*/