mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 01:11:44 +08:00
CPU LU factorization and linear solvers (#1451)
* linalg solve backend * nits * more nits + fix * luf primitive and lu, solve, and solve_triangular backends * changes / nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -465,3 +465,95 @@ TEST_CASE("test matrix eigh") {
|
||||
// Verify eigendecomposition
|
||||
CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test lu") {
|
||||
// Test 2x2 matrix
|
||||
array a = array({1., 2., 3., 4.}, {2, 2});
|
||||
auto out = linalg::lu(a, Device::cpu);
|
||||
auto L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
|
||||
array expected = matmul(L, out[2]);
|
||||
CHECK(allclose(a, expected).item<bool>());
|
||||
|
||||
// Test 3x3 matrix
|
||||
a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
|
||||
out = linalg::lu(a, Device::cpu);
|
||||
L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
|
||||
expected = matmul(L, out[2]);
|
||||
CHECK(allclose(a, expected).item<bool>());
|
||||
|
||||
// Test batch dimension
|
||||
a = broadcast_to(a, {3, 3, 3});
|
||||
out = linalg::lu(a, Device::cpu);
|
||||
L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
|
||||
expected = matmul(L, out[2]);
|
||||
CHECK(allclose(a, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test solve") {
|
||||
// 0D and 1D throw
|
||||
CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu));
|
||||
CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu));
|
||||
|
||||
// Unsupported types throw
|
||||
CHECK_THROWS(
|
||||
linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu));
|
||||
|
||||
// Non-square throws
|
||||
array a = reshape(arange(6), {3, 2});
|
||||
array b = reshape(arange(3), {3, 1});
|
||||
CHECK_THROWS(linalg::solve(a, b, Device::cpu));
|
||||
|
||||
// Test 2x2 matrix with 1D rhs
|
||||
a = array({2., 1., 1., 3.}, {2, 2});
|
||||
b = array({8., 13.}, {2});
|
||||
|
||||
array result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
|
||||
// Test 3x3 matrix
|
||||
a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
|
||||
b = array({6., 15., 25.}, {3, 1});
|
||||
|
||||
result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
|
||||
// Test batch dimension
|
||||
a = broadcast_to(a, {5, 3, 3});
|
||||
b = broadcast_to(b, {5, 3, 1});
|
||||
|
||||
result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
|
||||
// Test multi-column rhs
|
||||
a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, {3, 3});
|
||||
b = array({4., 2., 5., 3., 6., 1.}, {3, 2});
|
||||
|
||||
result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
|
||||
// Test batch multi-column rhs
|
||||
a = broadcast_to(a, {5, 3, 3});
|
||||
b = broadcast_to(b, {5, 3, 2});
|
||||
|
||||
result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test solve_triangluar") {
|
||||
// Test lower triangular matrix
|
||||
array a = array({2., 0., 0., 3., 1., 0., 1., -1., 1.}, {3, 3});
|
||||
array b = array({2., 5., 0.});
|
||||
|
||||
array result =
|
||||
linalg::solve_triangular(a, b, /* upper = */ false, Device::cpu);
|
||||
array expected = array({1., 2., 1.});
|
||||
CHECK(allclose(expected, result).item<bool>());
|
||||
|
||||
// Test upper triangular matrix
|
||||
a = array({2., 1., 3., 0., 4., 2., 0., 0., 1.}, {3, 3});
|
||||
b = array({5., 14., 3.});
|
||||
|
||||
result = linalg::solve_triangular(a, b, /* upper = */ true, Device::cpu);
|
||||
expected = array({-3., 2., 3.});
|
||||
CHECK(allclose(expected, result).item<bool>());
|
||||
}
|
||||
|
Reference in New Issue
Block a user