CPU LU factorization and linear solvers (#1451)

* linalg solve backend

* nits

* more nits + fix

* luf primitive and lu, solve, and solve_triangular backends

* changes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Abe Leininger
2025-02-10 14:32:24 -06:00
committed by GitHub
parent 7df3f792a2
commit a5ededf1c3
12 changed files with 571 additions and 15 deletions

View File

@@ -59,6 +59,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp

88
mlx/backend/cpu/luf.cpp Normal file
View File

@@ -0,0 +1,88 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
namespace mlx::core {
void lu_factor_impl(
const array& a,
array& lu,
array& pivots,
array& row_indices) {
int M = a.shape(-2);
int N = a.shape(-1);
// Copy a into lu and make it col contiguous
auto ndim = lu.ndim();
auto flags = lu.flags();
flags.col_contiguous = ndim == 2;
flags.row_contiguous = false;
flags.contiguous = true;
auto strides = lu.strides();
strides[ndim - 1] = M;
strides[ndim - 2] = 1;
lu.set_data(
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
copy_inplace(
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
auto a_ptr = lu.data<float>();
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
auto pivots_ptr = pivots.data<uint32_t>();
auto row_indices_ptr = row_indices.data<uint32_t>();
int info;
size_t num_matrices = a.size() / (M * N);
for (size_t i = 0; i < num_matrices; ++i) {
// Compute LU factorization of A
MLX_LAPACK_FUNC(sgetrf)
(/* m */ &M,
/* n */ &N,
/* a */ a_ptr,
/* lda */ &M,
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
/* info */ &info);
if (info != 0) {
std::stringstream ss;
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
<< ((info > 0) ? " because matrix is singular"
: " because argument had an illegal value");
throw std::runtime_error(ss.str());
}
// Subtract 1 to get 0-based index
for (int j = 0; j < pivots.shape(-1); ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];
auto t2 = row_indices_ptr[j];
row_indices_ptr[j] = t1;
row_indices_ptr[piv] = t2;
}
// Advance pointers to the next matrix
a_ptr += M * N;
pivots_ptr += pivots.shape(-1);
row_indices_ptr += pivots.shape(-1);
}
}
void LUF::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}
} // namespace mlx::core

View File

@@ -554,6 +554,12 @@ void Eigh::eval_gpu(
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
}
void LUF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());

View File

@@ -81,6 +81,7 @@ NO_CPU(LogicalNot)
NO_CPU(LogicalAnd)
NO_CPU(LogicalOr)
NO_CPU(LogAddExp)
NO_CPU_MULTI(LUF)
NO_CPU(Matmul)
NO_CPU(Maximum)
NO_CPU(Minimum)

View File

@@ -81,6 +81,7 @@ NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)