mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
CPU LU factorization and linear solvers (#1451)
* linalg solve backend * nits * more nits + fix * luf primitive and lu, solve, and solve_triangular backends * changes / nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
7df3f792a2
commit
a5ededf1c3
@ -5,8 +5,8 @@ Linear Algebra
|
||||
|
||||
.. currentmodule:: mlx.core.linalg
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
inv
|
||||
tri_inv
|
||||
@ -18,3 +18,7 @@ Linear Algebra
|
||||
svd
|
||||
eigvalsh
|
||||
eigh
|
||||
lu
|
||||
lu_factor
|
||||
solve
|
||||
solve_triangular
|
||||
|
@ -59,6 +59,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
|
||||
|
88
mlx/backend/cpu/luf.cpp
Normal file
88
mlx/backend/cpu/luf.cpp
Normal file
@ -0,0 +1,88 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/lapack.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void lu_factor_impl(
|
||||
const array& a,
|
||||
array& lu,
|
||||
array& pivots,
|
||||
array& row_indices) {
|
||||
int M = a.shape(-2);
|
||||
int N = a.shape(-1);
|
||||
|
||||
// Copy a into lu and make it col contiguous
|
||||
auto ndim = lu.ndim();
|
||||
auto flags = lu.flags();
|
||||
flags.col_contiguous = ndim == 2;
|
||||
flags.row_contiguous = false;
|
||||
flags.contiguous = true;
|
||||
auto strides = lu.strides();
|
||||
strides[ndim - 1] = M;
|
||||
strides[ndim - 2] = 1;
|
||||
lu.set_data(
|
||||
allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
|
||||
copy_inplace(
|
||||
a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);
|
||||
|
||||
auto a_ptr = lu.data<float>();
|
||||
|
||||
pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
|
||||
row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
|
||||
auto pivots_ptr = pivots.data<uint32_t>();
|
||||
auto row_indices_ptr = row_indices.data<uint32_t>();
|
||||
|
||||
int info;
|
||||
size_t num_matrices = a.size() / (M * N);
|
||||
for (size_t i = 0; i < num_matrices; ++i) {
|
||||
// Compute LU factorization of A
|
||||
MLX_LAPACK_FUNC(sgetrf)
|
||||
(/* m */ &M,
|
||||
/* n */ &N,
|
||||
/* a */ a_ptr,
|
||||
/* lda */ &M,
|
||||
/* ipiv */ reinterpret_cast<int*>(pivots_ptr),
|
||||
/* info */ &info);
|
||||
|
||||
if (info != 0) {
|
||||
std::stringstream ss;
|
||||
ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
|
||||
<< ((info > 0) ? " because matrix is singular"
|
||||
: " because argument had an illegal value");
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
// Subtract 1 to get 0-based index
|
||||
for (int j = 0; j < pivots.shape(-1); ++j) {
|
||||
pivots_ptr[j]--;
|
||||
row_indices_ptr[j] = j;
|
||||
}
|
||||
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
|
||||
auto piv = pivots_ptr[j];
|
||||
auto t1 = row_indices_ptr[piv];
|
||||
auto t2 = row_indices_ptr[j];
|
||||
row_indices_ptr[j] = t1;
|
||||
row_indices_ptr[piv] = t2;
|
||||
}
|
||||
|
||||
// Advance pointers to the next matrix
|
||||
a_ptr += M * N;
|
||||
pivots_ptr += pivots.shape(-1);
|
||||
row_indices_ptr += pivots.shape(-1);
|
||||
}
|
||||
}
|
||||
|
||||
void LUF::eval_cpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
@ -554,6 +554,12 @@ void Eigh::eval_gpu(
|
||||
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
|
||||
}
|
||||
|
||||
void LUF::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
|
||||
}
|
||||
|
||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
auto ibytes = size_of(in.dtype());
|
||||
|
@ -81,6 +81,7 @@ NO_CPU(LogicalNot)
|
||||
NO_CPU(LogicalAnd)
|
||||
NO_CPU(LogicalOr)
|
||||
NO_CPU(LogAddExp)
|
||||
NO_CPU_MULTI(LUF)
|
||||
NO_CPU(Matmul)
|
||||
NO_CPU(Maximum)
|
||||
NO_CPU(Minimum)
|
||||
|
@ -81,6 +81,7 @@ NO_GPU(LogicalNot)
|
||||
NO_GPU(LogicalAnd)
|
||||
NO_GPU(LogicalOr)
|
||||
NO_GPU(LogAddExp)
|
||||
NO_GPU_MULTI(LUF)
|
||||
NO_GPU(Matmul)
|
||||
NO_GPU(Maximum)
|
||||
NO_GPU(Minimum)
|
||||
|
132
mlx/linalg.cpp
132
mlx/linalg.cpp
@ -282,7 +282,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array tri_inv(
|
||||
const array& a,
|
||||
bool upper /* = true */,
|
||||
bool upper /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return inv_impl(a, /*tri=*/true, upper, s);
|
||||
}
|
||||
@ -519,4 +519,134 @@ std::pair<array, array> eigh(
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
void validate_lu(
|
||||
const array& a,
|
||||
const StreamOrDevice& stream,
|
||||
const std::string& fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << fname
|
||||
<< " Arrays must have >= 2 dimensions. Received array "
|
||||
"with "
|
||||
<< a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (a.shape(-1) != a.shape(-2)) {
|
||||
throw std::invalid_argument(fname + " Only defined for square matrices.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||
int m = a.shape()[a.shape().size() - 2];
|
||||
int n = a.shape()[a.shape().size() - 1];
|
||||
|
||||
Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
|
||||
pivots_shape.push_back(std::min(m, n));
|
||||
|
||||
return array::make_arrays(
|
||||
{a.shape(), pivots_shape, pivots_shape},
|
||||
{a.dtype(), uint32, uint32},
|
||||
std::make_shared<LUF>(to_stream(s)),
|
||||
{astype(a, a.dtype(), s)});
|
||||
}
|
||||
|
||||
std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
|
||||
validate_lu(a, s, "[linalg::lu]");
|
||||
|
||||
auto out = lu_helper(a, s);
|
||||
auto& LU = out[0];
|
||||
auto& row_pivots = out[2];
|
||||
|
||||
int N = a.shape(-1);
|
||||
auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
|
||||
auto U = triu(LU, /* k = */ 0, s);
|
||||
return {row_pivots, L, U};
|
||||
}
|
||||
|
||||
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
|
||||
validate_lu(a, s, "[linalg::lu_factor]");
|
||||
auto out = lu_helper(a, s);
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
void validate_solve(
|
||||
const array& a,
|
||||
const array& b,
|
||||
const StreamOrDevice& stream,
|
||||
const std::string& fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " First input must have >= 2 dimensions. "
|
||||
<< "Received array with " << a.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (b.ndim() < 1) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Second input must have >= 1 dimensions. "
|
||||
<< "Received array with " << b.ndim() << " dimensions.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
if (a.shape(-1) != a.shape(-2)) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " First input must be a square matrix. "
|
||||
<< "Received array with shape " << a.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
int lastDim = b.ndim() > 1 ? -2 : -1;
|
||||
if (a.shape(-1) != b.shape(lastDim)) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Last dimension of first input with shape " << a.shape()
|
||||
<< " must match second to last dimension of"
|
||||
<< " second input with shape " << b.shape() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (out_type != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Input arrays must promote to float32. Received arrays "
|
||||
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
validate_solve(a, b, s, "[linalg::solve]");
|
||||
|
||||
// P, L, U matrices
|
||||
const auto luf = lu(a, s);
|
||||
auto perm = argsort(luf[0], -1, s);
|
||||
int take_axis = -1;
|
||||
if (b.ndim() >= 2) {
|
||||
perm = expand_dims(perm, -1, s);
|
||||
take_axis -= 1;
|
||||
}
|
||||
auto pb = take_along_axis(b, perm, take_axis);
|
||||
auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
|
||||
return solve_triangular(luf[2], y, /* upper = */ true, s);
|
||||
}
|
||||
|
||||
array solve_triangular(
|
||||
const array& a,
|
||||
const array& b,
|
||||
bool upper /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
validate_solve(a, b, s, "[linalg::solve_triangular]");
|
||||
auto a_inv = tri_inv(a, upper, s);
|
||||
return matmul(a_inv, b, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
|
12
mlx/linalg.h
12
mlx/linalg.h
@ -74,6 +74,18 @@ array pinv(const array& a, StreamOrDevice s = {});
|
||||
|
||||
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
|
||||
|
||||
std::vector<array> lu(const array& a, StreamOrDevice s = {});
|
||||
|
||||
std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});
|
||||
|
||||
array solve(const array& a, const array& b, StreamOrDevice s = {});
|
||||
|
||||
array solve_triangular(
|
||||
const array& a,
|
||||
const array& b,
|
||||
bool upper = false,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/**
|
||||
* Compute the cross product of two arrays along the given axis.
|
||||
*/
|
||||
|
@ -2329,7 +2329,6 @@ class Eigh : public Primitive {
|
||||
: Primitive(stream),
|
||||
uplo_(std::move(uplo)),
|
||||
compute_eigenvectors_(compute_eigenvectors) {}
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
@ -2350,4 +2349,16 @@ class Eigh : public Primitive {
|
||||
bool compute_eigenvectors_;
|
||||
};
|
||||
|
||||
/* LU Factorization primitive. */
|
||||
class LUF : public Primitive {
|
||||
public:
|
||||
explicit LUF(Stream stream) : Primitive(stream) {}
|
||||
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
|
||||
override;
|
||||
|
||||
DEFINE_PRINT(LUF)
|
||||
};
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -14,13 +14,6 @@ namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
|
||||
namespace {
|
||||
nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
||||
const auto result = mx::linalg::svd(a, s);
|
||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void init_linalg(nb::module_& parent_module) {
|
||||
auto m = parent_module.def_submodule(
|
||||
"linalg", "mlx.core.linalg: linear algebra routines.");
|
||||
@ -213,7 +206,10 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"svd",
|
||||
&svd_helper,
|
||||
[](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
||||
const auto result = mx::linalg::svd(a, s);
|
||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
},
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@ -262,7 +258,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
"tri_inv",
|
||||
&mx::linalg::tri_inv,
|
||||
"a"_a,
|
||||
"upper"_a,
|
||||
"upper"_a = false,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
@ -276,7 +272,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
upper (array): Whether the array is upper or lower triangular. Defaults to ``False``.
|
||||
upper (bool, optional): Whether the array is upper or lower triangular. Defaults to ``False``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
@ -441,7 +437,6 @@ void init_linalg(nb::module_& parent_module) {
|
||||
m.def(
|
||||
"eigh",
|
||||
[](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
|
||||
// TODO avoid cast?
|
||||
auto result = mx::linalg::eigh(a, UPLO, s);
|
||||
return nb::make_tuple(result.first, result.second);
|
||||
},
|
||||
@ -484,4 +479,102 @@ void init_linalg(nb::module_& parent_module) {
|
||||
array([[ 0.707107, -0.707107],
|
||||
[ 0.707107, 0.707107]], dtype=float32)
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"lu",
|
||||
[](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
||||
auto result = mx::linalg::lu(a, s);
|
||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
},
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def lu(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
|
||||
R"pbdoc(
|
||||
Compute the LU factorization of the given matrix ``A``.
|
||||
|
||||
Note, unlike the default behavior of ``scipy.linalg.lu``, the pivots
|
||||
are indices. To reconstruct the input use ``L[P, :] @ U`` for 2
|
||||
dimensions or ``mx.take_along_axis(L, P[..., None], axis=-2) @ U``
|
||||
for more than 2 dimensions.
|
||||
|
||||
To construct the full permuation matrix do:
|
||||
|
||||
.. code-block::
|
||||
|
||||
P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
tuple(array, array, array):
|
||||
The ``p``, ``L``, and ``U`` arrays, such that ``A = L[P, :] @ U``
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"lu_factor",
|
||||
&mx::linalg::lu_factor,
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def lu_factor(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||
R"pbdoc(
|
||||
Computes a compact representation of the LU factorization.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
tuple(array, array): The ``LU`` matrix and ``pivots`` array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"solve",
|
||||
&mx::linalg::solve,
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the solution to a system of linear equations ``AX = B``.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
b (array): Input array.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The unique solution to the system ``AX = B``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"solve_triangular",
|
||||
&mx::linalg::solve_triangular,
|
||||
"a"_a,
|
||||
"b"_a,
|
||||
nb::kw_only(),
|
||||
"upper"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def solve_triangular(a: array, b: array, *, upper: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Computes the solution of a triangular system of linear equations ``AX = B``.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
b (array): Input array.
|
||||
upper (bool, optional): Whether the array is upper or lower
|
||||
triangular. Default: ``False``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The unique solution to the system ``AX = B``.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -330,6 +330,123 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
|
||||
) # Non-square matrix
|
||||
|
||||
def test_lu(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mx.linalg.lu(mx.array(0.0), stream=mx.cpu)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu)
|
||||
|
||||
# Test 3x3 matrix
|
||||
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
|
||||
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
|
||||
self.assertTrue(mx.allclose(L[P, :] @ U, a))
|
||||
|
||||
# Test batch dimension
|
||||
a = mx.broadcast_to(a, (5, 5, 3, 3))
|
||||
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
|
||||
L = mx.take_along_axis(L, P[..., None], axis=-2)
|
||||
self.assertTrue(mx.allclose(L @ U, a))
|
||||
|
||||
def test_lu_factor(self):
|
||||
mx.random.seed(7)
|
||||
|
||||
# Test 3x3 matrix
|
||||
a = mx.random.uniform(shape=(5, 5))
|
||||
LU, pivots = mx.linalg.lu_factor(a, stream=mx.cpu)
|
||||
n = a.shape[-1]
|
||||
|
||||
pivots = pivots.tolist()
|
||||
perm = list(range(n))
|
||||
for i in range(len(pivots)):
|
||||
perm[i], perm[pivots[i]] = perm[pivots[i]], perm[i]
|
||||
|
||||
L = mx.add(mx.tril(LU, k=-1), mx.eye(n))
|
||||
U = mx.triu(LU)
|
||||
self.assertTrue(mx.allclose(L @ U, a[perm, :]))
|
||||
|
||||
def test_solve(self):
|
||||
mx.random.seed(7)
|
||||
|
||||
# Test 3x3 matrix with 1D rhs
|
||||
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
|
||||
b = mx.array([11.0, 35.0, 28.0])
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test symmetric positive-definite matrix
|
||||
N = 5
|
||||
a = mx.random.uniform(shape=(N, N))
|
||||
a = mx.matmul(a, a.T) + N * mx.eye(N)
|
||||
b = mx.random.uniform(shape=(N, 1))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test batch dimension
|
||||
a = mx.random.uniform(shape=(5, 5, 4, 4))
|
||||
b = mx.random.uniform(shape=(5, 5, 4, 1))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected, atol=1e-5))
|
||||
|
||||
# Test large matrix
|
||||
N = 1000
|
||||
a = mx.random.uniform(shape=(N, N))
|
||||
b = mx.random.uniform(shape=(N, 1))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected, atol=1e-3))
|
||||
|
||||
# Test multi-column rhs
|
||||
a = mx.random.uniform(shape=(5, 5))
|
||||
b = mx.random.uniform(shape=(5, 8))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test batched multi-column rhs
|
||||
a = mx.broadcast_to(a, (3, 2, 5, 5))
|
||||
b = mx.broadcast_to(b, (3, 1, 5, 8))
|
||||
|
||||
result = mx.linalg.solve(a, b, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))
|
||||
|
||||
def test_solve_triangular(self):
|
||||
# Test lower triangular matrix
|
||||
a = mx.array([[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]])
|
||||
b = mx.array([8.0, 14.0, 3.0])
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=False, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test upper triangular matrix
|
||||
a = mx.array([[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]])
|
||||
b = mx.array([13.0, 33.0, 18.0])
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
# Test batch multi-column rhs
|
||||
a = mx.broadcast_to(a, (3, 4, 3, 3))
|
||||
b = mx.broadcast_to(mx.expand_dims(b, -1), (3, 4, 3, 8))
|
||||
|
||||
result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
|
||||
expected = np.linalg.solve(a, b)
|
||||
self.assertTrue(np.allclose(result, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -465,3 +465,95 @@ TEST_CASE("test matrix eigh") {
|
||||
// Verify eigendecomposition
|
||||
CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test lu") {
|
||||
// Test 2x2 matrix
|
||||
array a = array({1., 2., 3., 4.}, {2, 2});
|
||||
auto out = linalg::lu(a, Device::cpu);
|
||||
auto L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
|
||||
array expected = matmul(L, out[2]);
|
||||
CHECK(allclose(a, expected).item<bool>());
|
||||
|
||||
// Test 3x3 matrix
|
||||
a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
|
||||
out = linalg::lu(a, Device::cpu);
|
||||
L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
|
||||
expected = matmul(L, out[2]);
|
||||
CHECK(allclose(a, expected).item<bool>());
|
||||
|
||||
// Test batch dimension
|
||||
a = broadcast_to(a, {3, 3, 3});
|
||||
out = linalg::lu(a, Device::cpu);
|
||||
L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
|
||||
expected = matmul(L, out[2]);
|
||||
CHECK(allclose(a, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test solve") {
|
||||
// 0D and 1D throw
|
||||
CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu));
|
||||
CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu));
|
||||
|
||||
// Unsupported types throw
|
||||
CHECK_THROWS(
|
||||
linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu));
|
||||
|
||||
// Non-square throws
|
||||
array a = reshape(arange(6), {3, 2});
|
||||
array b = reshape(arange(3), {3, 1});
|
||||
CHECK_THROWS(linalg::solve(a, b, Device::cpu));
|
||||
|
||||
// Test 2x2 matrix with 1D rhs
|
||||
a = array({2., 1., 1., 3.}, {2, 2});
|
||||
b = array({8., 13.}, {2});
|
||||
|
||||
array result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
|
||||
// Test 3x3 matrix
|
||||
a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
|
||||
b = array({6., 15., 25.}, {3, 1});
|
||||
|
||||
result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
|
||||
// Test batch dimension
|
||||
a = broadcast_to(a, {5, 3, 3});
|
||||
b = broadcast_to(b, {5, 3, 1});
|
||||
|
||||
result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
|
||||
// Test multi-column rhs
|
||||
a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, {3, 3});
|
||||
b = array({4., 2., 5., 3., 6., 1.}, {3, 2});
|
||||
|
||||
result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
|
||||
// Test batch multi-column rhs
|
||||
a = broadcast_to(a, {5, 3, 3});
|
||||
b = broadcast_to(b, {5, 3, 2});
|
||||
|
||||
result = linalg::solve(a, b, Device::cpu);
|
||||
CHECK(allclose(matmul(a, result), b).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test solve_triangluar") {
|
||||
// Test lower triangular matrix
|
||||
array a = array({2., 0., 0., 3., 1., 0., 1., -1., 1.}, {3, 3});
|
||||
array b = array({2., 5., 0.});
|
||||
|
||||
array result =
|
||||
linalg::solve_triangular(a, b, /* upper = */ false, Device::cpu);
|
||||
array expected = array({1., 2., 1.});
|
||||
CHECK(allclose(expected, result).item<bool>());
|
||||
|
||||
// Test upper triangular matrix
|
||||
a = array({2., 1., 3., 0., 4., 2., 0., 0., 1.}, {3, 3});
|
||||
b = array({5., 14., 3.});
|
||||
|
||||
result = linalg::solve_triangular(a, b, /* upper = */ true, Device::cpu);
|
||||
expected = array({-3., 2., 3.});
|
||||
CHECK(allclose(expected, result).item<bool>());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user