CPU LU factorization and linear solvers (#1451)

* linalg solve backend

* nits

* more nits + fix

* luf primitive and lu, solve, and solve_triangular backends

* changes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Abe Leininger
2025-02-10 14:32:24 -06:00
committed by GitHub
parent 7df3f792a2
commit a5ededf1c3
12 changed files with 571 additions and 15 deletions

View File

@@ -282,7 +282,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
// Inverse of a triangular matrix.
//
// `upper` selects which triangle of `a` is treated as the triangular matrix.
// The work is delegated to inv_impl with the triangular flag set.
//
// NOTE(review): the parameter list previously carried `upper` twice (the
// stale `/* = true */` line next to its `/* = false */` replacement); only
// the replacement is kept so the definition is well-formed.
array tri_inv(
    const array& a,
    bool upper /* = false */,
    StreamOrDevice s /* = {} */) {
  return inv_impl(a, /*tri=*/true, upper, s);
}
@@ -519,4 +519,134 @@ std::pair<array, array> eigh(
return std::make_pair(out[0], out[1]);
}
// Checks the preconditions shared by the LU entry points.
//
// Runs check_cpu_stream on `stream`, then throws std::invalid_argument
// (prefixed with `fname`) when `a` is not float32, has fewer than two
// dimensions, or its trailing two dimensions differ.
void validate_lu(
    const array& a,
    const StreamOrDevice& stream,
    const std::string& fname) {
  check_cpu_stream(stream, fname);

  // Element type: only float32 is accepted.
  const auto input_type = a.dtype();
  if (input_type != float32) {
    std::ostringstream err;
    err << fname << " Arrays must type float32. Received array "
        << "with type " << input_type << ".";
    throw std::invalid_argument(err.str());
  }

  // Rank: at least a single matrix is required.
  const auto rank = a.ndim();
  if (rank < 2) {
    std::ostringstream err;
    err << fname
        << " Arrays must have >= 2 dimensions. Received array "
           "with "
        << rank << " dimensions.";
    throw std::invalid_argument(err.str());
  }

  // Shape: the trailing matrix must be square.
  const auto n_cols = a.shape(-1);
  const auto n_rows = a.shape(-2);
  if (n_cols != n_rows) {
    throw std::invalid_argument(fname + " Only defined for square matrices.");
  }
}
// Builds the three outputs of the LUF primitive for `a`.
//
// Output 0 has the shape and dtype of `a`. Outputs 1 and 2 are uint32 arrays
// whose shape is the leading (batch) dimensions of `a` followed by
// min(rows, cols). Callers in this file read them as the packed LU factors,
// the pivots and the row pivots respectively.
// No validation happens here; callers are expected to validate first.
std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
  const auto& dims = a.shape();
  const auto rank = dims.size();
  const int rows = dims[rank - 2];
  const int cols = dims[rank - 1];

  // Batch dimensions followed by the number of pivots per matrix.
  Shape pivots_shape(dims.begin(), dims.end() - 2);
  pivots_shape.push_back(std::min(rows, cols));

  return array::make_arrays(
      {dims, pivots_shape, pivots_shape},
      {a.dtype(), uint32, uint32},
      std::make_shared<LUF>(to_stream(s)),
      {astype(a, a.dtype(), s)});
}
std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
validate_lu(a, s, "[linalg::lu]");
auto out = lu_helper(a, s);
auto& LU = out[0];
auto& row_pivots = out[2];
int N = a.shape(-1);
auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
auto U = triu(LU, /* k = */ 0, s);
return {row_pivots, L, U};
}
// LU factorization in packed form.
//
// Validates `a` with validate_lu and returns the first two outputs of
// lu_helper: the packed LU array and the pivots array.
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
  validate_lu(a, s, "[linalg::lu_factor]");
  auto factors = lu_helper(a, s);
  return std::pair<array, array>(factors[0], factors[1]);
}
// Checks the preconditions shared by the linear-solve entry points.
//
// Runs check_cpu_stream on `stream`, then throws std::invalid_argument
// (prefixed with `fname`) when:
//   - `a` has fewer than two dimensions,
//   - `b` has fewer than one dimension,
//   - the trailing matrix of `a` is not square,
//   - the last dimension of `a` does not match the row dimension of `b`
//     (second-to-last axis for a matrix `b`, the only axis for a vector `b`),
//   - the promoted dtype of `a` and `b` is not float32.
void validate_solve(
    const array& a,
    const array& b,
    const StreamOrDevice& stream,
    const std::string& fname) {
  check_cpu_stream(stream, fname);

  const auto a_rank = a.ndim();
  if (a_rank < 2) {
    std::ostringstream err;
    err << fname << " First input must have >= 2 dimensions. "
        << "Received array with " << a_rank << " dimensions.";
    throw std::invalid_argument(err.str());
  }

  const auto b_rank = b.ndim();
  if (b_rank < 1) {
    std::ostringstream err;
    err << fname << " Second input must have >= 1 dimensions. "
        << "Received array with " << b_rank << " dimensions.";
    throw std::invalid_argument(err.str());
  }

  const auto a_cols = a.shape(-1);
  if (a_cols != a.shape(-2)) {
    std::ostringstream err;
    err << fname << " First input must be a square matrix. "
        << "Received array with shape " << a.shape() << ".";
    throw std::invalid_argument(err.str());
  }

  // A vector right-hand side has its rows on the only axis it has.
  const int b_row_axis = (b_rank > 1) ? -2 : -1;
  if (a_cols != b.shape(b_row_axis)) {
    std::ostringstream err;
    err << fname << " Last dimension of first input with shape " << a.shape()
        << " must match second to last dimension of"
        << " second input with shape " << b.shape() << ".";
    throw std::invalid_argument(err.str());
  }

  const auto promoted = promote_types(a.dtype(), b.dtype());
  if (promoted != float32) {
    std::ostringstream err;
    err << fname << " Input arrays must promote to float32. Received arrays "
        << "with type " << a.dtype() << " and " << b.dtype() << ".";
    throw std::invalid_argument(err.str());
  }
}
// Solves the linear system a * x = b through an LU factorization.
//
// Steps:
//   1. Validate the inputs (validate_solve throws std::invalid_argument).
//   2. Factor `a` with lu(), which yields {row_pivots, L, U}.
//   3. Reorder the rows of `b` using the argsort of the row pivots.
//   4. Solve the lower-triangular system, then the upper-triangular one.
//
// Every op is issued on `s`; the row gather previously omitted the stream
// argument and therefore ran on the default stream instead.
array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  validate_solve(a, b, s, "[linalg::solve]");

  // P, L, U matrices
  const auto luf = lu(a, s);
  auto perm = argsort(luf[0], -1, s);

  // For a matrix right-hand side the rows live on the second-to-last axis,
  // and the permutation needs a trailing axis to line up with the columns.
  int take_axis = -1;
  if (b.ndim() >= 2) {
    perm = expand_dims(perm, -1, s);
    take_axis -= 1;
  }

  auto pb = take_along_axis(b, perm, take_axis, s);
  auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
  return solve_triangular(luf[2], y, /* upper = */ true, s);
}
// Solves a * x = b for a triangular `a`.
//
// Inputs are checked with validate_solve; the solution is formed by
// multiplying the triangular inverse of `a` (tri_inv) with `b`. `upper`
// is forwarded to tri_inv to select the triangle.
array solve_triangular(
    const array& a,
    const array& b,
    bool upper /* = false */,
    StreamOrDevice s /* = {} */) {
  validate_solve(a, b, s, "[linalg::solve_triangular]");
  return matmul(tri_inv(a, upper, s), b, s);
}
} // namespace mlx::core::linalg