mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
CPU LU factorization and linear solvers (#1451)
* linalg solve backend * nits * more nits + fix * luf primitive and lu, solve, and solve_triangular backends * changes / nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -59,6 +59,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
|
88
mlx/backend/cpu/luf.cpp
Normal file
88
mlx/backend/cpu/luf.cpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Batched LU factorization on CPU via LAPACK sgetrf.
//
// Inputs:
//   a           batch of M x N float32 matrices.
// Outputs (allocated here):
//   lu          packed L and U factors, stored column-major per matrix.
//   pivots      0-based LAPACK pivot indices (uint32), one row swap per entry.
//   row_indices explicit row permutation derived from `pivots` (uint32).
//
// Throws std::runtime_error if sgetrf reports failure for any matrix.
void lu_factor_impl(
    const array& a,
    array& lu,
    array& pivots,
    array& row_indices) {
  int M = a.shape(-2);
  int N = a.shape(-1);

  // Copy a into lu and make it col contiguous.
  // sgetrf expects Fortran (column-major) order; swapping the last two
  // strides before the copy lays each matrix out column-major without a
  // separate transpose pass.
  auto ndim = lu.ndim();
  auto flags = lu.flags();
  flags.col_contiguous = ndim == 2;
  flags.row_contiguous = false;
  flags.contiguous = true;
  auto strides = lu.strides();
  strides[ndim - 1] = M;
  strides[ndim - 2] = 1;
  lu.set_data(
      allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
  copy_inplace(
      a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);

  auto a_ptr = lu.data<float>();

  pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
  row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
  auto pivots_ptr = pivots.data<uint32_t>();
  auto row_indices_ptr = row_indices.data<uint32_t>();

  int info;
  size_t num_matrices = a.size() / (M * N);
  for (size_t i = 0; i < num_matrices; ++i) {
    // Compute LU factorization of A. sgetrf writes 1-based interchange
    // indices directly into the pivots buffer (as int, reinterpreted below).
    MLX_LAPACK_FUNC(sgetrf)
    (/* m */ &M,
     /* n */ &N,
     /* a */ a_ptr,
     /* lda */ &M,
     /* ipiv */ reinterpret_cast<int*>(pivots_ptr),
     /* info */ &info);

    if (info != 0) {
      // info > 0: U(info, info) is exactly zero (singular matrix);
      // info < 0: argument -info had an illegal value.
      std::stringstream ss;
      ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
         << ((info > 0) ? " because matrix is singular"
                        : " because argument had an illegal value");
      throw std::runtime_error(ss.str());
    }

    // Subtract 1 to get 0-based index, and initialize row_indices to the
    // identity permutation.
    for (int j = 0; j < pivots.shape(-1); ++j) {
      pivots_ptr[j]--;
      row_indices_ptr[j] = j;
    }
    // Turn the sequence of row swaps in `pivots` into an explicit row
    // permutation. NOTE(review): swaps are applied in reverse order —
    // presumably to produce the inverse of the forward-swap permutation;
    // confirm against how linalg::solve consumes row_indices (it argsorts
    // this output before take_along_axis).
    for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
      auto piv = pivots_ptr[j];
      auto t1 = row_indices_ptr[piv];
      auto t2 = row_indices_ptr[j];
      row_indices_ptr[j] = t1;
      row_indices_ptr[piv] = t2;
    }

    // Advance pointers to the next matrix.
    a_ptr += M * N;
    pivots_ptr += pivots.shape(-1);
    row_indices_ptr += pivots.shape(-1);
  }
}
|
||||
|
||||
void LUF::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@@ -554,6 +554,12 @@ void Eigh::eval_gpu(
|
||||
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
|
||||
}
|
||||
|
||||
// GPU evaluation of LU factorization is not yet implemented for Metal;
// parameters are intentionally unnamed since they are never used.
void LUF::eval_gpu(
    const std::vector<array>& /* inputs */,
    std::vector<array>& /* outputs */) {
  throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
}
|
||||
|
||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
|
@@ -81,6 +81,7 @@ NO_CPU(LogicalNot)
|
||||
NO_CPU(LogicalAnd)
|
||||
NO_CPU(LogicalOr)
|
||||
NO_CPU(LogAddExp)
|
||||
NO_CPU_MULTI(LUF)
|
||||
NO_CPU(Matmul)
|
||||
NO_CPU(Maximum)
|
||||
NO_CPU(Minimum)
|
||||
|
@@ -81,6 +81,7 @@ NO_GPU(LogicalNot)
|
||||
NO_GPU(LogicalAnd)
|
||||
NO_GPU(LogicalOr)
|
||||
NO_GPU(LogAddExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Matmul)
|
||||
NO_GPU(Maximum)
|
||||
NO_GPU(Minimum)
|
||||
|
132
mlx/linalg.cpp
132
mlx/linalg.cpp
@@ -282,7 +282,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
||||
// Invert a triangular matrix. `upper` selects which triangle holds the data
// (default: lower). Delegates to inv_impl with the triangular flag set.
//
// Fix: the scraped diff left both the old (`= true`) and new (`= false`)
// default-comment lines for `upper`, i.e. two parameters with the same name,
// which does not compile. Keep the post-change signature (default false),
// matching the `bool upper = false` declaration in linalg.h.
array tri_inv(
    const array& a,
    bool upper /* = false */,
    StreamOrDevice s /* = {} */) {
  return inv_impl(a, /*tri=*/true, upper, s);
}
|
||||
@@ -519,4 +519,134 @@ std::pair<array, array> eigh(
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
void validate_lu(
|
||||
const array& a,
|
||||
const StreamOrDevice& stream,
|
||||
const std::string& fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << fname
|
||||
<< " Arrays must have >= 2 dimensions. Received array "
|
||||
"with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (a.shape(-1) != a.shape(-2)) {
|
||||
throw std::invalid_argument(fname + " Only defined for square matrices.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
int m = a.shape()[a.shape().size() - 2];
|
||||
int n = a.shape()[a.shape().size() - 1];
|
||||
|
||||
Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
|
||||
pivots_shape.push_back(std::min(m, n));
|
||||
|
||||
return array::make_arrays(
|
||||
{a.shape(), pivots_shape, pivots_shape},
|
||||
{a.dtype(), uint32, uint32},
|
||||
std::make_shared<LUF>(to_stream(s)),
|
||||
{astype(a, a.dtype(), s)});
|
||||
}
|
||||
|
||||
std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
|
||||
validate_lu(a, s, "[linalg::lu]");
|
||||
|
||||
auto out = lu_helper(a, s);
|
||||
auto& LU = out[0];
|
||||
auto& row_pivots = out[2];
|
||||
|
||||
int N = a.shape(-1);
|
||||
auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
|
||||
auto U = triu(LU, /* k = */ 0, s);
|
||||
return {row_pivots, L, U};
|
||||
}
|
||||
|
||||
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
|
||||
validate_lu(a, s, "[linalg::lu_factor]");
|
||||
auto out = lu_helper(a, s);
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
void validate_solve(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const StreamOrDevice& stream,
|
||||
const std::string& fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " First input must have >= 2 dimensions. "
|
||||
<< "Received array with " << a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (b.ndim() < 1) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Second input must have >= 1 dimensions. "
|
||||
<< "Received array with " << b.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (a.shape(-1) != a.shape(-2)) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " First input must be a square matrix. "
|
||||
<< "Received array with shape " << a.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
int lastDim = b.ndim() > 1 ? -2 : -1;
|
||||
if (a.shape(-1) != b.shape(lastDim)) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Last dimension of first input with shape " << a.shape()
|
||||
<< " must match second to last dimension of"
|
||||
<< " second input with shape " << b.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (out_type != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Input arrays must promote to float32. Received arrays "
|
||||
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
// Solve the linear system a x = b via LU decomposition:
// permute b by the LU row pivots, then forward- and back-substitute
// through the triangular factors.
array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  validate_solve(a, b, s, "[linalg::solve]");

  // Factor a into P, L, U.
  const auto plu = lu(a, s);

  // Build the gather indices that apply the row permutation to b. For
  // matrix (or batched) right-hand sides, gather along the row axis.
  auto order = argsort(plu[0], -1, s);
  int gather_axis = -1;
  if (b.ndim() >= 2) {
    order = expand_dims(order, -1, s);
    gather_axis -= 1;
  }
  auto permuted_b = take_along_axis(b, order, gather_axis);

  // Forward substitution with L, then back substitution with U.
  auto partial = solve_triangular(plu[1], permuted_b, /* upper = */ false, s);
  return solve_triangular(plu[2], partial, /* upper = */ true, s);
}
|
||||
|
||||
// Solve a triangular system a x = b by explicitly inverting the triangular
// matrix and multiplying: x = a^{-1} b. `upper` selects which triangle of
// a holds the coefficients.
array solve_triangular(
    const array& a,
    const array& b,
    bool upper /* = false */,
    StreamOrDevice s /* = {} */) {
  validate_solve(a, b, s, "[linalg::solve_triangular]");
  auto inv_a = tri_inv(a, upper, s);
  return matmul(inv_a, b, s);
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
|
12
mlx/linalg.h
12
mlx/linalg.h
@@ -74,6 +74,18 @@ array pinv(const array& a, StreamOrDevice s = {});
|
||||
|
||||
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
|
||||
|
||||
std::vector<array> lu(const array& a, StreamOrDevice s = {});
|
||||
|
||||
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});
|
||||
|
||||
array solve(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
array solve_triangular(
|
||||
const array& a,
|
||||
const array& b,
|
||||
bool upper = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Compute the cross product of two arrays along the given axis.
|
||||
*/
|
||||
|
@@ -2329,7 +2329,6 @@ class Eigh : public Primitive {
|
||||
: Primitive(stream),
|
||||
uplo_(std::move(uplo)),
|
||||
compute_eigenvectors_(compute_eigenvectors) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
@@ -2350,4 +2349,16 @@ class Eigh : public Primitive {
|
||||
bool compute_eigenvectors_;
|
||||
};
|
||||
|
||||
/* LU Factorization primitive.
 *
 * Multi-output primitive producing {packed LU factors, pivot indices,
 * row permutation indices} from a single input array. CPU evaluation is
 * backed by LAPACK sgetrf; GPU evaluation throws (Metal NYI). */
class LUF : public Primitive {
 public:
  explicit LUF(Stream stream) : Primitive(stream) {}
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(LUF)
};
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Reference in New Issue
Block a user