CPU LU factorization and linear solvers (#1451)

* linalg solve backend

* nits

* more nits + fix

* luf primitive and lu, solve, and solve_triangular backends

* changes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Abe Leininger
2025-02-10 14:32:24 -06:00
committed by GitHub
parent 7df3f792a2
commit a5ededf1c3
12 changed files with 571 additions and 15 deletions

View File

@@ -465,3 +465,95 @@ TEST_CASE("test matrix eigh") {
// Verify eigendecomposition
CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
}
TEST_CASE("test lu") {
// Test 2x2 matrix
array a = array({1., 2., 3., 4.}, {2, 2});
auto out = linalg::lu(a, Device::cpu);
auto L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
array expected = matmul(L, out[2]);
CHECK(allclose(a, expected).item<bool>());
// Test 3x3 matrix
a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
out = linalg::lu(a, Device::cpu);
L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
expected = matmul(L, out[2]);
CHECK(allclose(a, expected).item<bool>());
// Test batch dimension
a = broadcast_to(a, {3, 3, 3});
out = linalg::lu(a, Device::cpu);
L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
expected = matmul(L, out[2]);
CHECK(allclose(a, expected).item<bool>());
}
TEST_CASE("test solve") {
// 0D and 1D throw
CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu));
CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu));
// Unsupported types throw
CHECK_THROWS(
linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu));
// Non-square throws
array a = reshape(arange(6), {3, 2});
array b = reshape(arange(3), {3, 1});
CHECK_THROWS(linalg::solve(a, b, Device::cpu));
// Test 2x2 matrix with 1D rhs
a = array({2., 1., 1., 3.}, {2, 2});
b = array({8., 13.}, {2});
array result = linalg::solve(a, b, Device::cpu);
CHECK(allclose(matmul(a, result), b).item<bool>());
// Test 3x3 matrix
a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
b = array({6., 15., 25.}, {3, 1});
result = linalg::solve(a, b, Device::cpu);
CHECK(allclose(matmul(a, result), b).item<bool>());
// Test batch dimension
a = broadcast_to(a, {5, 3, 3});
b = broadcast_to(b, {5, 3, 1});
result = linalg::solve(a, b, Device::cpu);
CHECK(allclose(matmul(a, result), b).item<bool>());
// Test multi-column rhs
a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, {3, 3});
b = array({4., 2., 5., 3., 6., 1.}, {3, 2});
result = linalg::solve(a, b, Device::cpu);
CHECK(allclose(matmul(a, result), b).item<bool>());
// Test batch multi-column rhs
a = broadcast_to(a, {5, 3, 3});
b = broadcast_to(b, {5, 3, 2});
result = linalg::solve(a, b, Device::cpu);
CHECK(allclose(matmul(a, result), b).item<bool>());
}
TEST_CASE("test solve_triangluar") {
// Test lower triangular matrix
array a = array({2., 0., 0., 3., 1., 0., 1., -1., 1.}, {3, 3});
array b = array({2., 5., 0.});
array result =
linalg::solve_triangular(a, b, /* upper = */ false, Device::cpu);
array expected = array({1., 2., 1.});
CHECK(allclose(expected, result).item<bool>());
// Test upper triangular matrix
a = array({2., 1., 3., 0., 4., 2., 0., 0., 1.}, {3, 3});
b = array({5., 14., 3.});
result = linalg::solve_triangular(a, b, /* upper = */ true, Device::cpu);
expected = array({-3., 2., 3.});
CHECK(allclose(expected, result).item<bool>());
}