mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
CPU LU factorization and linear solvers (#1451)
* linalg solve backend * nits * more nits + fix * luf primitive and lu, solve, and solve_triangular backends * changes / nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
7df3f792a2
commit
a5ededf1c3
@ -5,8 +5,8 @@ Linear Algebra
|
|||||||
|
|
||||||
.. currentmodule:: mlx.core.linalg
|
.. currentmodule:: mlx.core.linalg
|
||||||
|
|
||||||
.. autosummary::
|
.. autosummary::
|
||||||
:toctree: _autosummary
|
:toctree: _autosummary
|
||||||
|
|
||||||
inv
|
inv
|
||||||
tri_inv
|
tri_inv
|
||||||
@ -18,3 +18,7 @@ Linear Algebra
|
|||||||
svd
|
svd
|
||||||
eigvalsh
|
eigvalsh
|
||||||
eigh
|
eigh
|
||||||
|
lu
|
||||||
|
lu_factor
|
||||||
|
solve
|
||||||
|
solve_triangular
|
||||||
|
@ -59,6 +59,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||||
|
88
mlx/backend/cpu/luf.cpp
Normal file
88
mlx/backend/cpu/luf.cpp
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "mlx/allocator.h"
|
||||||
|
#include "mlx/backend/cpu/copy.h"
|
||||||
|
#include "mlx/backend/cpu/lapack.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void lu_factor_impl(
|
||||||
|
const array& a,
|
||||||
|
array& lu,
|
||||||
|
array& pivots,
|
||||||
|
array& row_indices) {
|
||||||
|
int M = a.shape(-2);
|
||||||
|
int N = a.shape(-1);
|
||||||
|
|
||||||
|
// Copy a into lu and make it col contiguous
|
||||||
|
auto ndim = lu.ndim();
|
||||||
|
auto flags = lu.flags();
|
||||||
|
flags.col_contiguous = ndim == 2;
|
||||||
|
flags.row_contiguous = false;
|
||||||
|
flags.contiguous = true;
|
||||||
|
auto strides = lu.strides();
|
||||||
|
strides[ndim - 1] = M;
|
||||||
|
strides[ndim - 2] = 1;
|
||||||
|
lu.set_data(
|
||||||
|
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||||
|
copy_inplace(
|
||||||
|
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
|
||||||
|
|
||||||
|
auto a_ptr = lu.data<float>();
|
||||||
|
|
||||||
|
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
|
||||||
|
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
|
||||||
|
auto pivots_ptr = pivots.data<uint32_t>();
|
||||||
|
auto row_indices_ptr = row_indices.data<uint32_t>();
|
||||||
|
|
||||||
|
int info;
|
||||||
|
size_t num_matrices = a.size() / (M * N);
|
||||||
|
for (size_t i = 0; i < num_matrices; ++i) {
|
||||||
|
// Compute LU factorization of A
|
||||||
|
MLX_LAPACK_FUNC(sgetrf)
|
||||||
|
(/* m */ &M,
|
||||||
|
/* n */ &N,
|
||||||
|
/* a */ a_ptr,
|
||||||
|
/* lda */ &M,
|
||||||
|
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
|
||||||
|
/* info */ &info);
|
||||||
|
|
||||||
|
if (info != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
|
||||||
|
<< ((info > 0) ? " because matrix is singular"
|
||||||
|
: " because argument had an illegal value");
|
||||||
|
throw std::runtime_error(ss.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subtract 1 to get 0-based index
|
||||||
|
for (int j = 0; j < pivots.shape(-1); ++j) {
|
||||||
|
pivots_ptr[j]--;
|
||||||
|
row_indices_ptr[j] = j;
|
||||||
|
}
|
||||||
|
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
|
||||||
|
auto piv = pivots_ptr[j];
|
||||||
|
auto t1 = row_indices_ptr[piv];
|
||||||
|
auto t2 = row_indices_ptr[j];
|
||||||
|
row_indices_ptr[j] = t1;
|
||||||
|
row_indices_ptr[piv] = t2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance pointers to the next matrix
|
||||||
|
a_ptr += M * N;
|
||||||
|
pivots_ptr += pivots.shape(-1);
|
||||||
|
row_indices_ptr += pivots.shape(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void LUF::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -554,6 +554,12 @@ void Eigh::eval_gpu(
|
|||||||
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
|
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LUF::eval_gpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
|
||||||
|
}
|
||||||
|
|
||||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
auto ibytes = size_of(in.dtype());
|
auto ibytes = size_of(in.dtype());
|
||||||
|
@ -81,6 +81,7 @@ NO_CPU(LogicalNot)
|
|||||||
NO_CPU(LogicalAnd)
|
NO_CPU(LogicalAnd)
|
||||||
NO_CPU(LogicalOr)
|
NO_CPU(LogicalOr)
|
||||||
NO_CPU(LogAddExp)
|
NO_CPU(LogAddExp)
|
||||||
|
NO_CPU_MULTI(LUF)
|
||||||
NO_CPU(Matmul)
|
NO_CPU(Matmul)
|
||||||
NO_CPU(Maximum)
|
NO_CPU(Maximum)
|
||||||
NO_CPU(Minimum)
|
NO_CPU(Minimum)
|
||||||
|
@ -81,6 +81,7 @@ NO_GPU(LogicalNot)
|
|||||||
NO_GPU(LogicalAnd)
|
NO_GPU(LogicalAnd)
|
||||||
NO_GPU(LogicalOr)
|
NO_GPU(LogicalOr)
|
||||||
NO_GPU(LogAddExp)
|
NO_GPU(LogAddExp)
|
||||||
|
NO_GPU_MULTI(LUF)
|
||||||
NO_GPU(Matmul)
|
NO_GPU(Matmul)
|
||||||
NO_GPU(Maximum)
|
NO_GPU(Maximum)
|
||||||
NO_GPU(Minimum)
|
NO_GPU(Minimum)
|
||||||
|
132
mlx/linalg.cpp
132
mlx/linalg.cpp
@ -282,7 +282,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
|
|
||||||
array tri_inv(
|
array tri_inv(
|
||||||
const array& a,
|
const array& a,
|
||||||
bool upper /* = true */,
|
bool upper /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return inv_impl(a, /*tri=*/true, upper, s);
|
return inv_impl(a, /*tri=*/true, upper, s);
|
||||||
}
|
}
|
||||||
@ -519,4 +519,134 @@ std::pair<array, array> eigh(
|
|||||||
return std::make_pair(out[0], out[1]);
|
return std::make_pair(out[0], out[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void validate_lu(
|
||||||
|
const array& a,
|
||||||
|
const StreamOrDevice& stream,
|
||||||
|
const std::string& fname) {
|
||||||
|
check_cpu_stream(stream, fname);
|
||||||
|
if (a.dtype() != float32) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << fname << " Arrays must type float32. Received array "
|
||||||
|
<< "with type " << a.dtype() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a.ndim() < 2) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << fname
|
||||||
|
<< " Arrays must have >= 2 dimensions. Received array "
|
||||||
|
"with "
|
||||||
|
<< a.ndim() << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a.shape(-1) != a.shape(-2)) {
|
||||||
|
throw std::invalid_argument(fname + " Only defined for square matrices.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
int m = a.shape()[a.shape().size() - 2];
|
||||||
|
int n = a.shape()[a.shape().size() - 1];
|
||||||
|
|
||||||
|
Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
|
||||||
|
pivots_shape.push_back(std::min(m, n));
|
||||||
|
|
||||||
|
return array::make_arrays(
|
||||||
|
{a.shape(), pivots_shape, pivots_shape},
|
||||||
|
{a.dtype(), uint32, uint32},
|
||||||
|
std::make_shared<LUF>(to_stream(s)),
|
||||||
|
{astype(a, a.dtype(), s)});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
validate_lu(a, s, "[linalg::lu]");
|
||||||
|
|
||||||
|
auto out = lu_helper(a, s);
|
||||||
|
auto& LU = out[0];
|
||||||
|
auto& row_pivots = out[2];
|
||||||
|
|
||||||
|
int N = a.shape(-1);
|
||||||
|
auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
|
||||||
|
auto U = triu(LU, /* k = */ 0, s);
|
||||||
|
return {row_pivots, L, U};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
validate_lu(a, s, "[linalg::lu_factor]");
|
||||||
|
auto out = lu_helper(a, s);
|
||||||
|
return std::make_pair(out[0], out[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void validate_solve(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
const StreamOrDevice& stream,
|
||||||
|
const std::string& fname) {
|
||||||
|
check_cpu_stream(stream, fname);
|
||||||
|
if (a.ndim() < 2) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << fname << " First input must have >= 2 dimensions. "
|
||||||
|
<< "Received array with " << a.ndim() << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (b.ndim() < 1) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << fname << " Second input must have >= 1 dimensions. "
|
||||||
|
<< "Received array with " << b.ndim() << " dimensions.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a.shape(-1) != a.shape(-2)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << fname << " First input must be a square matrix. "
|
||||||
|
<< "Received array with shape " << a.shape() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
int lastDim = b.ndim() > 1 ? -2 : -1;
|
||||||
|
if (a.shape(-1) != b.shape(lastDim)) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << fname << " Last dimension of first input with shape " << a.shape()
|
||||||
|
<< " must match second to last dimension of"
|
||||||
|
<< " second input with shape " << b.shape() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||||
|
if (out_type != float32) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << fname << " Input arrays must promote to float32. Received arrays "
|
||||||
|
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||||
|
validate_solve(a, b, s, "[linalg::solve]");
|
||||||
|
|
||||||
|
// P, L, U matrices
|
||||||
|
const auto luf = lu(a, s);
|
||||||
|
auto perm = argsort(luf[0], -1, s);
|
||||||
|
int take_axis = -1;
|
||||||
|
if (b.ndim() >= 2) {
|
||||||
|
perm = expand_dims(perm, -1, s);
|
||||||
|
take_axis -= 1;
|
||||||
|
}
|
||||||
|
auto pb = take_along_axis(b, perm, take_axis);
|
||||||
|
auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
|
||||||
|
return solve_triangular(luf[2], y, /* upper = */ true, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array solve_triangular(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
bool upper /* = false */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
validate_solve(a, b, s, "[linalg::solve_triangular]");
|
||||||
|
auto a_inv = tri_inv(a, upper, s);
|
||||||
|
return matmul(a_inv, b, s);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::linalg
|
} // namespace mlx::core::linalg
|
||||||
|
12
mlx/linalg.h
12
mlx/linalg.h
@ -74,6 +74,18 @@ array pinv(const array& a, StreamOrDevice s = {});
|
|||||||
|
|
||||||
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
|
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
std::vector<array> lu(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
array solve(const array& a, const array& b, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
array solve_triangular(
|
||||||
|
const array& a,
|
||||||
|
const array& b,
|
||||||
|
bool upper = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute the cross product of two arrays along the given axis.
|
* Compute the cross product of two arrays along the given axis.
|
||||||
*/
|
*/
|
||||||
|
@ -2329,7 +2329,6 @@ class Eigh : public Primitive {
|
|||||||
: Primitive(stream),
|
: Primitive(stream),
|
||||||
uplo_(std::move(uplo)),
|
uplo_(std::move(uplo)),
|
||||||
compute_eigenvectors_(compute_eigenvectors) {}
|
compute_eigenvectors_(compute_eigenvectors) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
override;
|
override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
@ -2350,4 +2349,16 @@ class Eigh : public Primitive {
|
|||||||
bool compute_eigenvectors_;
|
bool compute_eigenvectors_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* LU Factorization primitive. */
|
||||||
|
class LUF : public Primitive {
|
||||||
|
public:
|
||||||
|
explicit LUF(Stream stream) : Primitive(stream) {}
|
||||||
|
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
|
override;
|
||||||
|
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||||
|
override;
|
||||||
|
|
||||||
|
DEFINE_PRINT(LUF)
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -14,13 +14,6 @@ namespace mx = mlx::core;
|
|||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
|
|
||||||
namespace {
|
|
||||||
nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
|
||||||
const auto result = mx::linalg::svd(a, s);
|
|
||||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
void init_linalg(nb::module_& parent_module) {
|
void init_linalg(nb::module_& parent_module) {
|
||||||
auto m = parent_module.def_submodule(
|
auto m = parent_module.def_submodule(
|
||||||
"linalg", "mlx.core.linalg: linear algebra routines.");
|
"linalg", "mlx.core.linalg: linear algebra routines.");
|
||||||
@ -213,7 +206,10 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"svd",
|
"svd",
|
||||||
&svd_helper,
|
[](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
||||||
|
const auto result = mx::linalg::svd(a, s);
|
||||||
|
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||||
|
},
|
||||||
"a"_a,
|
"a"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -262,7 +258,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
"tri_inv",
|
"tri_inv",
|
||||||
&mx::linalg::tri_inv,
|
&mx::linalg::tri_inv,
|
||||||
"a"_a,
|
"a"_a,
|
||||||
"upper"_a,
|
"upper"_a = false,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
@ -276,7 +272,7 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
a (array): Input array.
|
a (array): Input array.
|
||||||
upper (array): Whether the array is upper or lower triangular. Defaults to ``False``.
|
upper (bool, optional): Whether the array is upper or lower triangular. Defaults to ``False``.
|
||||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
in which case the default stream of the default device is used.
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
@ -441,7 +437,6 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
m.def(
|
m.def(
|
||||||
"eigh",
|
"eigh",
|
||||||
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
|
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
|
||||||
// TODO avoid cast?
|
|
||||||
auto result = mx::linalg::eigh(a, UPLO, s);
|
auto result = mx::linalg::eigh(a, UPLO, s);
|
||||||
return nb::make_tuple(result.first, result.second);
|
return nb::make_tuple(result.first, result.second);
|
||||||
},
|
},
|
||||||
@ -484,4 +479,102 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
array([[ 0.707107, -0.707107],
|
array([[ 0.707107, -0.707107],
|
||||||
[ 0.707107, 0.707107]], dtype=float32)
|
[ 0.707107, 0.707107]], dtype=float32)
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"lu",
|
||||||
|
[](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
||||||
|
auto result = mx::linalg::lu(a, s);
|
||||||
|
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def lu(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
|
||||||
|
R"pbdoc(
|
||||||
|
Compute the LU factorization of the given matrix ``A``.
|
||||||
|
|
||||||
|
Note, unlike the default behavior of ``scipy.linalg.lu``, the pivots
|
||||||
|
are indices. To reconstruct the input use ``L[P, :] @ U`` for 2
|
||||||
|
dimensions or ``mx.take_along_axis(L, P[..., None], axis=-2) @ U``
|
||||||
|
for more than 2 dimensions.
|
||||||
|
|
||||||
|
To construct the full permuation matrix do:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple(array, array, array):
|
||||||
|
The ``p``, ``L``, and ``U`` arrays, such that ``A = L[P, :] @ U``
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"lu_factor",
|
||||||
|
&mx::linalg::lu_factor,
|
||||||
|
"a"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def lu_factor(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||||
|
R"pbdoc(
|
||||||
|
Computes a compact representation of the LU factorization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple(array, array): The ``LU`` matrix and ``pivots`` array.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"solve",
|
||||||
|
&mx::linalg::solve,
|
||||||
|
"a"_a,
|
||||||
|
"b"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Compute the solution to a system of linear equations ``AX = B``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
b (array): Input array.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The unique solution to the system ``AX = B``.
|
||||||
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"solve_triangular",
|
||||||
|
&mx::linalg::solve_triangular,
|
||||||
|
"a"_a,
|
||||||
|
"b"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"upper"_a = false,
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def solve_triangular(a: array, b: array, *, upper: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Computes the solution of a triangular system of linear equations ``AX = B``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
b (array): Input array.
|
||||||
|
upper (bool, optional): Whether the array is upper or lower
|
||||||
|
triangular. Default: ``False``.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The unique solution to the system ``AX = B``.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
@ -330,6 +330,123 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
|
mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
|
||||||
) # Non-square matrix
|
) # Non-square matrix
|
||||||
|
|
||||||
|
def test_lu(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.linalg.lu(mx.array(0.0), stream=mx.cpu)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu)
|
||||||
|
|
||||||
|
# Test 3x3 matrix
|
||||||
|
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
|
||||||
|
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
|
||||||
|
self.assertTrue(mx.allclose(L[P, :] @ U, a))
|
||||||
|
|
||||||
|
# Test batch dimension
|
||||||
|
a = mx.broadcast_to(a, (5, 5, 3, 3))
|
||||||
|
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
|
||||||
|
L = mx.take_along_axis(L, P[..., None], axis=-2)
|
||||||
|
self.assertTrue(mx.allclose(L @ U, a))
|
||||||
|
|
||||||
|
def test_lu_factor(self):
|
||||||
|
mx.random.seed(7)
|
||||||
|
|
||||||
|
# Test 3x3 matrix
|
||||||
|
a = mx.random.uniform(shape=(5, 5))
|
||||||
|
LU, pivots = mx.linalg.lu_factor(a, stream=mx.cpu)
|
||||||
|
n = a.shape[-1]
|
||||||
|
|
||||||
|
pivots = pivots.tolist()
|
||||||
|
perm = list(range(n))
|
||||||
|
for i in range(len(pivots)):
|
||||||
|
perm[i], perm[pivots[i]] = perm[pivots[i]], perm[i]
|
||||||
|
|
||||||
|
L = mx.add(mx.tril(LU, k=-1), mx.eye(n))
|
||||||
|
U = mx.triu(LU)
|
||||||
|
self.assertTrue(mx.allclose(L @ U, a[perm, :]))
|
||||||
|
|
||||||
|
def test_solve(self):
|
||||||
|
mx.random.seed(7)
|
||||||
|
|
||||||
|
# Test 3x3 matrix with 1D rhs
|
||||||
|
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
|
||||||
|
b = mx.array([11.0, 35.0, 28.0])
|
||||||
|
|
||||||
|
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||||
|
expected = np.linalg.solve(a, b)
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
# Test symmetric positive-definite matrix
|
||||||
|
N = 5
|
||||||
|
a = mx.random.uniform(shape=(N, N))
|
||||||
|
a = mx.matmul(a, a.T) + N * mx.eye(N)
|
||||||
|
b = mx.random.uniform(shape=(N, 1))
|
||||||
|
|
||||||
|
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||||
|
expected = np.linalg.solve(a, b)
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
# Test batch dimension
|
||||||
|
a = mx.random.uniform(shape=(5, 5, 4, 4))
|
||||||
|
b = mx.random.uniform(shape=(5, 5, 4, 1))
|
||||||
|
|
||||||
|
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||||
|
expected = np.linalg.solve(a, b)
|
||||||
|
self.assertTrue(np.allclose(result, expected, atol=1e-5))
|
||||||
|
|
||||||
|
# Test large matrix
|
||||||
|
N = 1000
|
||||||
|
a = mx.random.uniform(shape=(N, N))
|
||||||
|
b = mx.random.uniform(shape=(N, 1))
|
||||||
|
|
||||||
|
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||||
|
expected = np.linalg.solve(a, b)
|
||||||
|
self.assertTrue(np.allclose(result, expected, atol=1e-3))
|
||||||
|
|
||||||
|
# Test multi-column rhs
|
||||||
|
a = mx.random.uniform(shape=(5, 5))
|
||||||
|
b = mx.random.uniform(shape=(5, 8))
|
||||||
|
|
||||||
|
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||||
|
expected = np.linalg.solve(a, b)
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
# Test batched multi-column rhs
|
||||||
|
a = mx.broadcast_to(a, (3, 2, 5, 5))
|
||||||
|
b = mx.broadcast_to(b, (3, 1, 5, 8))
|
||||||
|
|
||||||
|
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||||
|
expected = np.linalg.solve(a, b)
|
||||||
|
self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))
|
||||||
|
|
||||||
|
def test_solve_triangular(self):
|
||||||
|
# Test lower triangular matrix
|
||||||
|
a = mx.array([[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]])
|
||||||
|
b = mx.array([8.0, 14.0, 3.0])
|
||||||
|
|
||||||
|
result = mx.linalg.solve_triangular(a, b, upper=False, stream=mx.cpu)
|
||||||
|
expected = np.linalg.solve(a, b)
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
# Test upper triangular matrix
|
||||||
|
a = mx.array([[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]])
|
||||||
|
b = mx.array([13.0, 33.0, 18.0])
|
||||||
|
|
||||||
|
result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
|
||||||
|
expected = np.linalg.solve(a, b)
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
# Test batch multi-column rhs
|
||||||
|
a = mx.broadcast_to(a, (3, 4, 3, 3))
|
||||||
|
b = mx.broadcast_to(mx.expand_dims(b, -1), (3, 4, 3, 8))
|
||||||
|
|
||||||
|
result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
|
||||||
|
expected = np.linalg.solve(a, b)
|
||||||
|
self.assertTrue(np.allclose(result, expected))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -465,3 +465,95 @@ TEST_CASE("test matrix eigh") {
|
|||||||
// Verify eigendecomposition
|
// Verify eigendecomposition
|
||||||
CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
|
CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test lu") {
|
||||||
|
// Test 2x2 matrix
|
||||||
|
array a = array({1., 2., 3., 4.}, {2, 2});
|
||||||
|
auto out = linalg::lu(a, Device::cpu);
|
||||||
|
auto L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
|
||||||
|
array expected = matmul(L, out[2]);
|
||||||
|
CHECK(allclose(a, expected).item<bool>());
|
||||||
|
|
||||||
|
// Test 3x3 matrix
|
||||||
|
a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
|
||||||
|
out = linalg::lu(a, Device::cpu);
|
||||||
|
L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
|
||||||
|
expected = matmul(L, out[2]);
|
||||||
|
CHECK(allclose(a, expected).item<bool>());
|
||||||
|
|
||||||
|
// Test batch dimension
|
||||||
|
a = broadcast_to(a, {3, 3, 3});
|
||||||
|
out = linalg::lu(a, Device::cpu);
|
||||||
|
L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
|
||||||
|
expected = matmul(L, out[2]);
|
||||||
|
CHECK(allclose(a, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test solve") {
|
||||||
|
// 0D and 1D throw
|
||||||
|
CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu));
|
||||||
|
CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu));
|
||||||
|
|
||||||
|
// Unsupported types throw
|
||||||
|
CHECK_THROWS(
|
||||||
|
linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu));
|
||||||
|
|
||||||
|
// Non-square throws
|
||||||
|
array a = reshape(arange(6), {3, 2});
|
||||||
|
array b = reshape(arange(3), {3, 1});
|
||||||
|
CHECK_THROWS(linalg::solve(a, b, Device::cpu));
|
||||||
|
|
||||||
|
// Test 2x2 matrix with 1D rhs
|
||||||
|
a = array({2., 1., 1., 3.}, {2, 2});
|
||||||
|
b = array({8., 13.}, {2});
|
||||||
|
|
||||||
|
array result = linalg::solve(a, b, Device::cpu);
|
||||||
|
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||||
|
|
||||||
|
// Test 3x3 matrix
|
||||||
|
a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
|
||||||
|
b = array({6., 15., 25.}, {3, 1});
|
||||||
|
|
||||||
|
result = linalg::solve(a, b, Device::cpu);
|
||||||
|
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||||
|
|
||||||
|
// Test batch dimension
|
||||||
|
a = broadcast_to(a, {5, 3, 3});
|
||||||
|
b = broadcast_to(b, {5, 3, 1});
|
||||||
|
|
||||||
|
result = linalg::solve(a, b, Device::cpu);
|
||||||
|
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||||
|
|
||||||
|
// Test multi-column rhs
|
||||||
|
a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, {3, 3});
|
||||||
|
b = array({4., 2., 5., 3., 6., 1.}, {3, 2});
|
||||||
|
|
||||||
|
result = linalg::solve(a, b, Device::cpu);
|
||||||
|
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||||
|
|
||||||
|
// Test batch multi-column rhs
|
||||||
|
a = broadcast_to(a, {5, 3, 3});
|
||||||
|
b = broadcast_to(b, {5, 3, 2});
|
||||||
|
|
||||||
|
result = linalg::solve(a, b, Device::cpu);
|
||||||
|
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test solve_triangluar") {
|
||||||
|
// Test lower triangular matrix
|
||||||
|
array a = array({2., 0., 0., 3., 1., 0., 1., -1., 1.}, {3, 3});
|
||||||
|
array b = array({2., 5., 0.});
|
||||||
|
|
||||||
|
array result =
|
||||||
|
linalg::solve_triangular(a, b, /* upper = */ false, Device::cpu);
|
||||||
|
array expected = array({1., 2., 1.});
|
||||||
|
CHECK(allclose(expected, result).item<bool>());
|
||||||
|
|
||||||
|
// Test upper triangular matrix
|
||||||
|
a = array({2., 1., 3., 0., 4., 2., 0., 0., 1.}, {3, 3});
|
||||||
|
b = array({5., 14., 3.});
|
||||||
|
|
||||||
|
result = linalg::solve_triangular(a, b, /* upper = */ true, Device::cpu);
|
||||||
|
expected = array({-3., 2., 3.});
|
||||||
|
CHECK(allclose(expected, result).item<bool>());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user