mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
CPU LU factorization and linear solvers (#1451)
* linalg solve backend * nits * more nits + fix * luf primitive and lu, solve, and solve_triangular backends * changes / nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
132
mlx/linalg.cpp
132
mlx/linalg.cpp
@@ -282,7 +282,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array tri_inv(
|
||||
const array& a,
|
||||
bool upper /* = true */,
|
||||
bool upper /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return inv_impl(a, /*tri=*/true, upper, s);
|
||||
}
|
||||
@@ -519,4 +519,134 @@ std::pair<array, array> eigh(
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
void validate_lu(
|
||||
const array& a,
|
||||
const StreamOrDevice& stream,
|
||||
const std::string& fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << fname
|
||||
<< " Arrays must have >= 2 dimensions. Received array "
|
||||
"with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (a.shape(-1) != a.shape(-2)) {
|
||||
throw std::invalid_argument(fname + " Only defined for square matrices.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
int m = a.shape()[a.shape().size() - 2];
|
||||
int n = a.shape()[a.shape().size() - 1];
|
||||
|
||||
Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
|
||||
pivots_shape.push_back(std::min(m, n));
|
||||
|
||||
return array::make_arrays(
|
||||
{a.shape(), pivots_shape, pivots_shape},
|
||||
{a.dtype(), uint32, uint32},
|
||||
std::make_shared<LUF>(to_stream(s)),
|
||||
{astype(a, a.dtype(), s)});
|
||||
}
|
||||
|
||||
std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
|
||||
validate_lu(a, s, "[linalg::lu]");
|
||||
|
||||
auto out = lu_helper(a, s);
|
||||
auto& LU = out[0];
|
||||
auto& row_pivots = out[2];
|
||||
|
||||
int N = a.shape(-1);
|
||||
auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
|
||||
auto U = triu(LU, /* k = */ 0, s);
|
||||
return {row_pivots, L, U};
|
||||
}
|
||||
|
||||
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
|
||||
validate_lu(a, s, "[linalg::lu_factor]");
|
||||
auto out = lu_helper(a, s);
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
void validate_solve(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const StreamOrDevice& stream,
|
||||
const std::string& fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " First input must have >= 2 dimensions. "
|
||||
<< "Received array with " << a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (b.ndim() < 1) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Second input must have >= 1 dimensions. "
|
||||
<< "Received array with " << b.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (a.shape(-1) != a.shape(-2)) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " First input must be a square matrix. "
|
||||
<< "Received array with shape " << a.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
int lastDim = b.ndim() > 1 ? -2 : -1;
|
||||
if (a.shape(-1) != b.shape(lastDim)) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Last dimension of first input with shape " << a.shape()
|
||||
<< " must match second to last dimension of"
|
||||
<< " second input with shape " << b.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (out_type != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Input arrays must promote to float32. Received arrays "
|
||||
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
validate_solve(a, b, s, "[linalg::solve]");
|
||||
|
||||
// P, L, U matrices
|
||||
const auto luf = lu(a, s);
|
||||
auto perm = argsort(luf[0], -1, s);
|
||||
int take_axis = -1;
|
||||
if (b.ndim() >= 2) {
|
||||
perm = expand_dims(perm, -1, s);
|
||||
take_axis -= 1;
|
||||
}
|
||||
auto pb = take_along_axis(b, perm, take_axis);
|
||||
auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
|
||||
return solve_triangular(luf[2], y, /* upper = */ true, s);
|
||||
}
|
||||
|
||||
array solve_triangular(
|
||||
const array& a,
|
||||
const array& b,
|
||||
bool upper /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_solve(a, b, s, "[linalg::solve_triangular]");
|
||||
auto a_inv = tri_inv(a, upper, s);
|
||||
return matmul(a_inv, b, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
|
||||
Reference in New Issue
Block a user