CPU LU factorization and linear solvers (#1451)

* linalg solve backend

* nits

* more nits + fix

* luf primitive and lu, solve, and solve_triangular backends

* changes / nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Abe Leininger 2025-02-10 14:32:24 -06:00, committed by GitHub
parent 7df3f792a2
commit a5ededf1c3
12 changed files with 571 additions and 15 deletions


@@ -5,8 +5,8 @@ Linear Algebra
.. currentmodule:: mlx.core.linalg

.. autosummary::
   :toctree: _autosummary

   inv
   tri_inv
@@ -18,3 +18,7 @@ Linear Algebra
   svd
   eigvalsh
   eigh
   lu
   lu_factor
   solve
   solve_triangular


@@ -59,6 +59,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp

mlx/backend/cpu/luf.cpp (new file, 88 lines)

@@ -0,0 +1,88 @@
// Copyright © 2024 Apple Inc.

#include <cassert>

#include "mlx/allocator.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"

namespace mlx::core {

void lu_factor_impl(
    const array& a,
    array& lu,
    array& pivots,
    array& row_indices) {
  int M = a.shape(-2);
  int N = a.shape(-1);

  // Copy a into lu and make it col contiguous
  auto ndim = lu.ndim();
  auto flags = lu.flags();
  flags.col_contiguous = ndim == 2;
  flags.row_contiguous = false;
  flags.contiguous = true;
  auto strides = lu.strides();
  strides[ndim - 1] = M;
  strides[ndim - 2] = 1;

  lu.set_data(
      allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
  copy_inplace(
      a, lu, a.shape(), a.strides(), strides, 0, 0, CopyType::GeneralGeneral);

  auto a_ptr = lu.data<float>();
  pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
  row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
  auto pivots_ptr = pivots.data<uint32_t>();
  auto row_indices_ptr = row_indices.data<uint32_t>();

  int info;
  size_t num_matrices = a.size() / (M * N);
  for (size_t i = 0; i < num_matrices; ++i) {
    // Compute LU factorization of A
    MLX_LAPACK_FUNC(sgetrf)
    (/* m */ &M,
     /* n */ &N,
     /* a */ a_ptr,
     /* lda */ &M,
     /* ipiv */ reinterpret_cast<int*>(pivots_ptr),
     /* info */ &info);

    if (info != 0) {
      std::stringstream ss;
      ss << "[LUF::eval_cpu] sgetrf_ failed with code " << info
         << ((info > 0) ? " because matrix is singular"
                        : " because argument had an illegal value");
      throw std::runtime_error(ss.str());
    }

    // Subtract 1 to get 0-based index
    for (int j = 0; j < pivots.shape(-1); ++j) {
      pivots_ptr[j]--;
      row_indices_ptr[j] = j;
    }
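    // Apply the swaps in reverse to turn the sequential pivots into a row permutation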
    for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
      auto piv = pivots_ptr[j];
      auto t1 = row_indices_ptr[piv];
      auto t2 = row_indices_ptr[j];
      row_indices_ptr[j] = t1;
      row_indices_ptr[piv] = t2;
    }

    // Advance pointers to the next matrix
    a_ptr += M * N;
    pivots_ptr += pivots.shape(-1);
    row_indices_ptr += pivots.shape(-1);
  }
}

void LUF::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]);
}

} // namespace mlx::core
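The reversed swap loop above turns LAPACK's sequential row swaps into a single permutation. A minimal Python sketch of the same conversion (`pivots_to_perm` is a hypothetical helper operating on plain lists, with already 0-based pivots):

def pivots_to_perm(pivots):
    n = len(pivots)
    perm = list(range(n))
    # Applying the swaps back-to-front yields p such that L[p, :] @ U == A
    for j in reversed(range(n)):
        perm[j], perm[pivots[j]] = perm[pivots[j]], perm[j]
    return perm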


@@ -554,6 +554,12 @@ void Eigh::eval_gpu(
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
}
void LUF::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
}
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());


@@ -81,6 +81,7 @@ NO_CPU(LogicalNot)
NO_CPU(LogicalAnd)
NO_CPU(LogicalOr)
NO_CPU(LogAddExp)
NO_CPU_MULTI(LUF)
NO_CPU(Matmul)
NO_CPU(Maximum)
NO_CPU(Minimum)


@@ -81,6 +81,7 @@ NO_GPU(LogicalNot)
NO_GPU(LogicalAnd)
NO_GPU(LogicalOr)
NO_GPU(LogAddExp)
NO_GPU_MULTI(LUF)
NO_GPU(Matmul)
NO_GPU(Maximum)
NO_GPU(Minimum)


@@ -282,7 +282,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
array tri_inv(
    const array& a,
-    bool upper /* = true */,
+    bool upper /* = false */,
    StreamOrDevice s /* = {} */) {
  return inv_impl(a, /*tri=*/true, upper, s);
}
@@ -519,4 +519,134 @@ std::pair<array, array> eigh(
  return std::make_pair(out[0], out[1]);
}

void validate_lu(
    const array& a,
    const StreamOrDevice& stream,
    const std::string& fname) {
  check_cpu_stream(stream, fname);
  if (a.dtype() != float32) {
    std::ostringstream msg;
    msg << fname << " Arrays must have type float32. Received array "
        << "with type " << a.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }

  if (a.ndim() < 2) {
    std::ostringstream msg;
    msg << fname
        << " Arrays must have >= 2 dimensions. Received array "
           "with "
        << a.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  if (a.shape(-1) != a.shape(-2)) {
    throw std::invalid_argument(fname + " Only defined for square matrices.");
  }
}
std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
  int m = a.shape()[a.shape().size() - 2];
  int n = a.shape()[a.shape().size() - 1];

  Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
  pivots_shape.push_back(std::min(m, n));

  return array::make_arrays(
      {a.shape(), pivots_shape, pivots_shape},
      {a.dtype(), uint32, uint32},
      std::make_shared<LUF>(to_stream(s)),
      {astype(a, a.dtype(), s)});
}

std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
  validate_lu(a, s, "[linalg::lu]");

  auto out = lu_helper(a, s);
  auto& LU = out[0];
  auto& row_pivots = out[2];

  int N = a.shape(-1);
  auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
  auto U = triu(LU, /* k = */ 0, s);
  return {row_pivots, L, U};
}

std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
  validate_lu(a, s, "[linalg::lu_factor]");
  auto out = lu_helper(a, s);
  return std::make_pair(out[0], out[1]);
}
void validate_solve(
    const array& a,
    const array& b,
    const StreamOrDevice& stream,
    const std::string& fname) {
  check_cpu_stream(stream, fname);
  if (a.ndim() < 2) {
    std::ostringstream msg;
    msg << fname << " First input must have >= 2 dimensions. "
        << "Received array with " << a.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  if (b.ndim() < 1) {
    std::ostringstream msg;
    msg << fname << " Second input must have >= 1 dimension. "
        << "Received array with " << b.ndim() << " dimensions.";
    throw std::invalid_argument(msg.str());
  }

  if (a.shape(-1) != a.shape(-2)) {
    std::ostringstream msg;
    msg << fname << " First input must be a square matrix. "
        << "Received array with shape " << a.shape() << ".";
    throw std::invalid_argument(msg.str());
  }

  int lastDim = b.ndim() > 1 ? -2 : -1;
  if (a.shape(-1) != b.shape(lastDim)) {
    std::ostringstream msg;
    msg << fname << " Last dimension of first input with shape " << a.shape()
        << " must match second to last dimension of"
        << " second input with shape " << b.shape() << ".";
    throw std::invalid_argument(msg.str());
  }

  auto out_type = promote_types(a.dtype(), b.dtype());
  if (out_type != float32) {
    std::ostringstream msg;
    msg << fname << " Input arrays must promote to float32. Received arrays "
        << "with type " << a.dtype() << " and " << b.dtype() << ".";
    throw std::invalid_argument(msg.str());
  }
}
array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
  validate_solve(a, b, s, "[linalg::solve]");

  // P, L, U matrices
  const auto luf = lu(a, s);

  auto perm = argsort(luf[0], -1, s);
  int take_axis = -1;
  if (b.ndim() >= 2) {
    perm = expand_dims(perm, -1, s);
    take_axis -= 1;
  }
  auto pb = take_along_axis(b, perm, take_axis);
  auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
  return solve_triangular(luf[2], y, /* upper = */ true, s);
}

array solve_triangular(
    const array& a,
    const array& b,
    bool upper /* = false */,
    StreamOrDevice s /* = {} */) {
  validate_solve(a, b, s, "[linalg::solve_triangular]");
  auto a_inv = tri_inv(a, upper, s);
  return matmul(a_inv, b, s);
}

} // namespace mlx::core::linalg
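The solve() path above translates nearly line-for-line into Python. A sketch (hypothetical helper name; these ops require a CPU stream):

import mlx.core as mx

def solve_via_lu(a, b):
    # Factor a so that L[p, :] @ U == a
    p, L, U = mx.linalg.lu(a, stream=mx.cpu)
    # argsort(p) inverts the permutation, lining up the rows of b with L @ U
    perm = mx.argsort(p)
    if b.ndim >= 2:
        pb = mx.take_along_axis(b, perm[..., None], axis=-2)
    else:
        pb = b[perm]
    # Forward substitution with L, then back substitution with U
    y = mx.linalg.solve_triangular(L, pb, upper=False, stream=mx.cpu)
    return mx.linalg.solve_triangular(U, y, upper=True, stream=mx.cpu)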


@@ -74,6 +74,18 @@ array pinv(const array& a, StreamOrDevice s = {});
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

std::vector<array> lu(const array& a, StreamOrDevice s = {});

std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});

array solve(const array& a, const array& b, StreamOrDevice s = {});

array solve_triangular(
    const array& a,
    const array& b,
    bool upper = false,
    StreamOrDevice s = {});

/**
 * Compute the cross product of two arrays along the given axis.
 */


@@ -2329,7 +2329,6 @@ class Eigh : public Primitive {
      : Primitive(stream),
        uplo_(std::move(uplo)),
        compute_eigenvectors_(compute_eigenvectors) {}
  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -2350,4 +2349,16 @@ class Eigh : public Primitive {
  bool compute_eigenvectors_;
};

/* LU Factorization primitive. */
class LUF : public Primitive {
 public:
  explicit LUF(Stream stream) : Primitive(stream) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  DEFINE_PRINT(LUF)
};

} // namespace mlx::core


@@ -14,13 +14,6 @@ namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;

-namespace {
-nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
-  const auto result = mx::linalg::svd(a, s);
-  return nb::make_tuple(result.at(0), result.at(1), result.at(2));
-}
-} // namespace
void init_linalg(nb::module_& parent_module) {
  auto m = parent_module.def_submodule(
      "linalg", "mlx.core.linalg: linear algebra routines.");
@@ -213,7 +206,10 @@ void init_linalg(nb::module_& parent_module) {
)pbdoc");
m.def(
"svd",
&svd_helper,
[](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
const auto result = mx::linalg::svd(a, s);
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
},
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
@@ -262,7 +258,7 @@ void init_linalg(nb::module_& parent_module) {
"tri_inv",
&mx::linalg::tri_inv,
"a"_a,
"upper"_a,
"upper"_a = false,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
@@ -276,7 +272,7 @@ void init_linalg(nb::module_& parent_module) {
        Args:
            a (array): Input array.
-            upper (array): Whether the array is upper or lower triangular. Defaults to ``False``.
+            upper (bool, optional): Whether the array is upper or lower triangular. Defaults to ``False``.
            stream (Stream, optional): Stream or device. Defaults to ``None``
              in which case the default stream of the default device is used.
@@ -441,7 +437,6 @@ void init_linalg(nb::module_& parent_module) {
  m.def(
      "eigh",
      [](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
-        // TODO avoid cast?
        auto result = mx::linalg::eigh(a, UPLO, s);
        return nb::make_tuple(result.first, result.second);
      },
@@ -484,4 +479,102 @@ void init_linalg(nb::module_& parent_module) {
              array([[ 0.707107, -0.707107],
                     [ 0.707107,  0.707107]], dtype=float32)
        )pbdoc");
  m.def(
      "lu",
      [](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
        auto result = mx::linalg::lu(a, s);
        return nb::make_tuple(result.at(0), result.at(1), result.at(2));
      },
      "a"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def lu(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
      R"pbdoc(
        Compute the LU factorization of the given matrix ``A``.

        Note that, unlike the default behavior of ``scipy.linalg.lu``, the
        pivots are indices. To reconstruct the input use ``L[p, :] @ U`` for
        2 dimensions or ``mx.take_along_axis(L, p[..., None], axis=-2) @ U``
        for more than 2 dimensions.

        To construct the full permutation matrix do:

        .. code-block:: python

          P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)

        Args:
            a (array): Input array.
            stream (Stream, optional): Stream or device. Defaults to ``None``
              in which case the default stream of the default device is used.

        Returns:
            tuple(array, array, array):
              The ``p``, ``L``, and ``U`` arrays, such that ``A = L[p, :] @ U``.
        )pbdoc");
  m.def(
      "lu_factor",
      &mx::linalg::lu_factor,
      "a"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def lu_factor(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
      R"pbdoc(
        Computes a compact representation of the LU factorization.

        Args:
            a (array): Input array.
            stream (Stream, optional): Stream or device. Defaults to ``None``
              in which case the default stream of the default device is used.

        Returns:
            tuple(array, array): The ``LU`` matrix and ``pivots`` array.
        )pbdoc");
  m.def(
      "solve",
      &mx::linalg::solve,
      "a"_a,
      "b"_a,
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
          "def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Compute the solution to a system of linear equations ``AX = B``.

        Args:
            a (array): Input array.
            b (array): Input array.
            stream (Stream, optional): Stream or device. Defaults to ``None``
              in which case the default stream of the default device is used.

        Returns:
            array: The unique solution to the system ``AX = B``.
        )pbdoc");
  m.def(
      "solve_triangular",
      &mx::linalg::solve_triangular,
      "a"_a,
      "b"_a,
      nb::kw_only(),
      "upper"_a = false,
      "stream"_a = nb::none(),
      nb::sig(
          "def solve_triangular(a: array, b: array, *, upper: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Compute the solution of a triangular system of linear equations ``AX = B``.

        Args:
            a (array): Input array.
            b (array): Input array.
            upper (bool, optional): Whether the array is upper or lower
              triangular. Default: ``False``.
            stream (Stream, optional): Stream or device. Defaults to ``None``
              in which case the default stream of the default device is used.

        Returns:
            array: The unique solution to the system ``AX = B``.
        )pbdoc");
}
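And a sketch for the triangular solver (lower-triangular case; values chosen so the answer is easy to verify by hand):

a = mx.array([[4.0, 0.0], [2.0, 3.0]])  # lower triangular
b = mx.array([8.0, 10.0])
x = mx.linalg.solve_triangular(a, b, upper=False, stream=mx.cpu)
# x == [2.0, 2.0]: 4 * 2 = 8, then 2 * 2 + 3 * 2 = 10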


@@ -330,6 +330,123 @@ class TestLinalg(mlx_tests.MLXTestCase):
                mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
            )  # Non-square matrix

    def test_lu(self):
        with self.assertRaises(ValueError):
            mx.linalg.lu(mx.array(0.0), stream=mx.cpu)

        with self.assertRaises(ValueError):
            mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu)

        with self.assertRaises(ValueError):
            mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu)

        # Test 3x3 matrix
        a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
        P, L, U = mx.linalg.lu(a, stream=mx.cpu)
        self.assertTrue(mx.allclose(L[P, :] @ U, a))

        # Test batch dimension
        a = mx.broadcast_to(a, (5, 5, 3, 3))
        P, L, U = mx.linalg.lu(a, stream=mx.cpu)
        L = mx.take_along_axis(L, P[..., None], axis=-2)
        self.assertTrue(mx.allclose(L @ U, a))

    def test_lu_factor(self):
        mx.random.seed(7)

        # Test 5x5 matrix
        a = mx.random.uniform(shape=(5, 5))
        LU, pivots = mx.linalg.lu_factor(a, stream=mx.cpu)
        n = a.shape[-1]

        pivots = pivots.tolist()
        perm = list(range(n))
        for i in range(len(pivots)):
            perm[i], perm[pivots[i]] = perm[pivots[i]], perm[i]

        L = mx.add(mx.tril(LU, k=-1), mx.eye(n))
        U = mx.triu(LU)
        self.assertTrue(mx.allclose(L @ U, a[perm, :]))

    def test_solve(self):
        mx.random.seed(7)

        # Test 3x3 matrix with 1D rhs
        a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
        b = mx.array([11.0, 35.0, 28.0])

        result = mx.linalg.solve(a, b, stream=mx.cpu)
        expected = np.linalg.solve(a, b)
        self.assertTrue(np.allclose(result, expected))

        # Test symmetric positive-definite matrix
        N = 5
        a = mx.random.uniform(shape=(N, N))
        a = mx.matmul(a, a.T) + N * mx.eye(N)
        b = mx.random.uniform(shape=(N, 1))

        result = mx.linalg.solve(a, b, stream=mx.cpu)
        expected = np.linalg.solve(a, b)
        self.assertTrue(np.allclose(result, expected))

        # Test batch dimension
        a = mx.random.uniform(shape=(5, 5, 4, 4))
        b = mx.random.uniform(shape=(5, 5, 4, 1))

        result = mx.linalg.solve(a, b, stream=mx.cpu)
        expected = np.linalg.solve(a, b)
        self.assertTrue(np.allclose(result, expected, atol=1e-5))

        # Test large matrix
        N = 1000
        a = mx.random.uniform(shape=(N, N))
        b = mx.random.uniform(shape=(N, 1))

        result = mx.linalg.solve(a, b, stream=mx.cpu)
        expected = np.linalg.solve(a, b)
        self.assertTrue(np.allclose(result, expected, atol=1e-3))

        # Test multi-column rhs
        a = mx.random.uniform(shape=(5, 5))
        b = mx.random.uniform(shape=(5, 8))

        result = mx.linalg.solve(a, b, stream=mx.cpu)
        expected = np.linalg.solve(a, b)
        self.assertTrue(np.allclose(result, expected))

        # Test batched multi-column rhs
        a = mx.broadcast_to(a, (3, 2, 5, 5))
        b = mx.broadcast_to(b, (3, 1, 5, 8))

        result = mx.linalg.solve(a, b, stream=mx.cpu)
        expected = np.linalg.solve(a, b)
        self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5))

    def test_solve_triangular(self):
        # Test lower triangular matrix
        a = mx.array([[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]])
        b = mx.array([8.0, 14.0, 3.0])

        result = mx.linalg.solve_triangular(a, b, upper=False, stream=mx.cpu)
        expected = np.linalg.solve(a, b)
        self.assertTrue(np.allclose(result, expected))

        # Test upper triangular matrix
        a = mx.array([[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]])
        b = mx.array([13.0, 33.0, 18.0])

        result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
        expected = np.linalg.solve(a, b)
        self.assertTrue(np.allclose(result, expected))

        # Test batch multi-column rhs
        a = mx.broadcast_to(a, (3, 4, 3, 3))
        b = mx.broadcast_to(mx.expand_dims(b, -1), (3, 4, 3, 8))

        result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
        expected = np.linalg.solve(a, b)
        self.assertTrue(np.allclose(result, expected))


if __name__ == "__main__":
    unittest.main()


@@ -465,3 +465,95 @@ TEST_CASE("test matrix eigh") {
  // Verify eigendecomposition
  CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
}

TEST_CASE("test lu") {
  // Test 2x2 matrix
  array a = array({1., 2., 3., 4.}, {2, 2});
  auto out = linalg::lu(a, Device::cpu);
  auto L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
  array expected = matmul(L, out[2]);
  CHECK(allclose(a, expected).item<bool>());

  // Test 3x3 matrix
  a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
  out = linalg::lu(a, Device::cpu);
  L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
  expected = matmul(L, out[2]);
  CHECK(allclose(a, expected).item<bool>());

  // Test batch dimension
  a = broadcast_to(a, {3, 3, 3});
  out = linalg::lu(a, Device::cpu);
  L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
  expected = matmul(L, out[2]);
  CHECK(allclose(a, expected).item<bool>());
}

TEST_CASE("test solve") {
  // 0D and 1D throw
  CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu));
  CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu));

  // Unsupported types throw
  CHECK_THROWS(
      linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu));

  // Non-square throws
  array a = reshape(arange(6), {3, 2});
  array b = reshape(arange(3), {3, 1});
  CHECK_THROWS(linalg::solve(a, b, Device::cpu));

  // Test 2x2 matrix with 1D rhs
  a = array({2., 1., 1., 3.}, {2, 2});
  b = array({8., 13.}, {2});

  array result = linalg::solve(a, b, Device::cpu);
  CHECK(allclose(matmul(a, result), b).item<bool>());

  // Test 3x3 matrix
  a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
  b = array({6., 15., 25.}, {3, 1});

  result = linalg::solve(a, b, Device::cpu);
  CHECK(allclose(matmul(a, result), b).item<bool>());

  // Test batch dimension
  a = broadcast_to(a, {5, 3, 3});
  b = broadcast_to(b, {5, 3, 1});

  result = linalg::solve(a, b, Device::cpu);
  CHECK(allclose(matmul(a, result), b).item<bool>());

  // Test multi-column rhs
  a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, {3, 3});
  b = array({4., 2., 5., 3., 6., 1.}, {3, 2});

  result = linalg::solve(a, b, Device::cpu);
  CHECK(allclose(matmul(a, result), b).item<bool>());

  // Test batch multi-column rhs
  a = broadcast_to(a, {5, 3, 3});
  b = broadcast_to(b, {5, 3, 2});

  result = linalg::solve(a, b, Device::cpu);
  CHECK(allclose(matmul(a, result), b).item<bool>());
}

TEST_CASE("test solve_triangular") {
  // Test lower triangular matrix
  array a = array({2., 0., 0., 3., 1., 0., 1., -1., 1.}, {3, 3});
  array b = array({2., 5., 0.});

  array result =
      linalg::solve_triangular(a, b, /* upper = */ false, Device::cpu);
  array expected = array({1., 2., 1.});
  CHECK(allclose(expected, result).item<bool>());

  // Test upper triangular matrix
  a = array({2., 1., 3., 0., 4., 2., 0., 0., 1.}, {3, 3});
  b = array({5., 14., 3.});

  result = linalg::solve_triangular(a, b, /* upper = */ true, Device::cpu);
  expected = array({-3., 2., 3.});
  CHECK(allclose(expected, result).item<bool>());
}