diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst
index f6c51ed0b..769f4bbb1 100644
--- a/docs/src/python/linalg.rst
+++ b/docs/src/python/linalg.rst
@@ -5,8 +5,8 @@ Linear Algebra

 .. currentmodule:: mlx.core.linalg

-.. autosummary:: 
-   :toctree: _autosummary
+.. autosummary::
+   :toctree: _autosummary

    inv
    tri_inv
@@ -18,3 +18,7 @@ Linear Algebra
    svd
    eigvalsh
    eigh
+   lu
+   lu_factor
+   solve
+   solve_triangular
diff --git a/mlx/backend/cpu/CMakeLists.txt b/mlx/backend/cpu/CMakeLists.txt
index b98f3985c..17a6e053e 100644
--- a/mlx/backend/cpu/CMakeLists.txt
+++ b/mlx/backend/cpu/CMakeLists.txt
@@ -59,6 +59,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/luf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
" because matrix is singular" + : " because argument had an illegal value"); + throw std::runtime_error(ss.str()); + } + + // Subtract 1 to get 0-based index + for (int j = 0; j < pivots.shape(-1); ++j) { + pivots_ptr[j]--; + row_indices_ptr[j] = j; + } + for (int j = pivots.shape(-1) - 1; j >= 0; --j) { + auto piv = pivots_ptr[j]; + auto t1 = row_indices_ptr[piv]; + auto t2 = row_indices_ptr[j]; + row_indices_ptr[j] = t1; + row_indices_ptr[piv] = t2; + } + + // Advance pointers to the next matrix + a_ptr += M * N; + pivots_ptr += pivots.shape(-1); + row_indices_ptr += pivots.shape(-1); + } +} + +void LUF::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + lu_factor_impl(inputs[0], outputs[0], outputs[1], outputs[2]); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 9aefd3f44..d07a66ee0 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -554,6 +554,12 @@ void Eigh::eval_gpu( throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI."); } +void LUF::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI."); +} + void View::eval_gpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; auto ibytes = size_of(in.dtype()); diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index ce04cd600..d6d6f91ed 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -81,6 +81,7 @@ NO_CPU(LogicalNot) NO_CPU(LogicalAnd) NO_CPU(LogicalOr) NO_CPU(LogAddExp) +NO_CPU_MULTI(LUF) NO_CPU(Matmul) NO_CPU(Maximum) NO_CPU(Minimum) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index f6d65ebe6..789a576db 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -81,6 +81,7 @@ NO_GPU(LogicalNot) NO_GPU(LogicalAnd) NO_GPU(LogicalOr) NO_GPU(LogAddExp) +NO_GPU_MULTI(LUF) NO_GPU(Matmul) NO_GPU(Maximum) NO_GPU(Minimum) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index c4a21a881..e9a0d6e5a 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -282,7 +282,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) { array tri_inv( const array& a, - bool upper /* = true */, + bool upper /* = false */, StreamOrDevice s /* = {} */) { return inv_impl(a, /*tri=*/true, upper, s); } @@ -519,4 +519,134 @@ std::pair eigh( return std::make_pair(out[0], out[1]); } +void validate_lu( + const array& a, + const StreamOrDevice& stream, + const std::string& fname) { + check_cpu_stream(stream, fname); + if (a.dtype() != float32) { + std::ostringstream msg; + msg << fname << " Arrays must type float32. Received array " + << "with type " << a.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + + if (a.ndim() < 2) { + std::ostringstream msg; + msg << fname + << " Arrays must have >= 2 dimensions. 
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 9aefd3f44..d07a66ee0 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -554,6 +554,12 @@ void Eigh::eval_gpu(
   throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI.");
 }

+void LUF::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  throw std::runtime_error("[LUF::eval_gpu] Metal LU factorization NYI.");
+}
+
 void View::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
   auto ibytes = size_of(in.dtype());
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index ce04cd600..d6d6f91ed 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -81,6 +81,7 @@ NO_CPU(LogicalNot)
 NO_CPU(LogicalAnd)
 NO_CPU(LogicalOr)
 NO_CPU(LogAddExp)
+NO_CPU_MULTI(LUF)
 NO_CPU(Matmul)
 NO_CPU(Maximum)
 NO_CPU(Minimum)
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index f6d65ebe6..789a576db 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -81,6 +81,7 @@ NO_GPU(LogicalNot)
 NO_GPU(LogicalAnd)
 NO_GPU(LogicalOr)
 NO_GPU(LogAddExp)
+NO_GPU_MULTI(LUF)
 NO_GPU(Matmul)
 NO_GPU(Maximum)
 NO_GPU(Minimum)
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index c4a21a881..e9a0d6e5a 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -282,7 +282,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {

 array tri_inv(
     const array& a,
-    bool upper /* = true */,
+    bool upper /* = false */,
     StreamOrDevice s /* = {} */) {
   return inv_impl(a, /*tri=*/true, upper, s);
 }
@@ -519,4 +519,134 @@ std::pair<array, array> eigh(
   return std::make_pair(out[0], out[1]);
 }

+void validate_lu(
+    const array& a,
+    const StreamOrDevice& stream,
+    const std::string& fname) {
+  check_cpu_stream(stream, fname);
+  if (a.dtype() != float32) {
+    std::ostringstream msg;
+    msg << fname << " Arrays must have type float32. Received array "
+        << "with type " << a.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << fname
+        << " Arrays must have >= 2 dimensions. Received array with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.shape(-1) != a.shape(-2)) {
+    throw std::invalid_argument(fname + " Only defined for square matrices.");
+  }
+}
+
+std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
+  int m = a.shape()[a.shape().size() - 2];
+  int n = a.shape()[a.shape().size() - 1];
+
+  Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
+  pivots_shape.push_back(std::min(m, n));
+
+  return array::make_arrays(
+      {a.shape(), pivots_shape, pivots_shape},
+      {a.dtype(), uint32, uint32},
+      std::make_shared<LUF>(to_stream(s)),
+      {astype(a, a.dtype(), s)});
+}
+
+std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
+  validate_lu(a, s, "[linalg::lu]");
+
+  auto out = lu_helper(a, s);
+  auto& LU = out[0];
+  auto& row_pivots = out[2];
+
+  int N = a.shape(-1);
+  auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
+  auto U = triu(LU, /* k = */ 0, s);
+  return {row_pivots, L, U};
+}
+
+std::pair<array, array> lu_factor(const array& a, StreamOrDevice s /* = {} */) {
+  validate_lu(a, s, "[linalg::lu_factor]");
+  auto out = lu_helper(a, s);
+  return std::make_pair(out[0], out[1]);
+}
+
+void validate_solve(
+    const array& a,
+    const array& b,
+    const StreamOrDevice& stream,
+    const std::string& fname) {
+  check_cpu_stream(stream, fname);
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << fname << " First input must have >= 2 dimensions. "
+        << "Received array with " << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (b.ndim() < 1) {
+    std::ostringstream msg;
+    msg << fname << " Second input must have >= 1 dimensions. "
+        << "Received array with " << b.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.shape(-1) != a.shape(-2)) {
+    std::ostringstream msg;
+    msg << fname << " First input must be a square matrix. "
+        << "Received array with shape " << a.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  int lastDim = b.ndim() > 1 ? -2 : -1;
+  if (a.shape(-1) != b.shape(lastDim)) {
+    std::ostringstream msg;
+    msg << fname << " Last dimension of first input with shape " << a.shape()
+        << " must match second to last dimension of"
+        << " second input with shape " << b.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  auto out_type = promote_types(a.dtype(), b.dtype());
+  if (out_type != float32) {
+    std::ostringstream msg;
+    msg << fname << " Input arrays must promote to float32. Received arrays "
+        << "with type " << a.dtype() << " and " << b.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+}
+
+array solve(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  validate_solve(a, b, s, "[linalg::solve]");
+
+  // P, L, U matrices
+  const auto luf = lu(a, s);
+  auto perm = argsort(luf[0], -1, s);
+  int take_axis = -1;
+  if (b.ndim() >= 2) {
+    perm = expand_dims(perm, -1, s);
+    take_axis -= 1;
+  }
+  auto pb = take_along_axis(b, perm, take_axis);
+  auto y = solve_triangular(luf[1], pb, /* upper = */ false, s);
+  return solve_triangular(luf[2], y, /* upper = */ true, s);
+}
+
+array solve_triangular(
+    const array& a,
+    const array& b,
+    bool upper /* = false */,
+    StreamOrDevice s /* = {} */) {
+  validate_solve(a, b, s, "[linalg::solve_triangular]");
+  auto a_inv = tri_inv(a, upper, s);
+  return matmul(a_inv, b, s);
+}
+
 } // namespace mlx::core::linalg
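The `solve` implementation above factors `A` once, permutes `b` by `argsort(P)` to undo the row pivoting, then does a forward substitution through `L` and a back substitution through `U`. A NumPy/SciPy sketch of the same strategy, for illustration only (assumes SciPy; note that MLX's `solve_triangular` in this diff is implemented as `tri_inv(a) @ b` rather than a substitution loop):

```python
import numpy as np
from scipy.linalg import lu, solve_triangular

def solve_via_lu(A, b):
    # SciPy convention: A == P @ L @ U, with P a permutation matrix.
    P, L, U = lu(A)
    pb = P.T @ b  # undo the row pivoting on b
    y = solve_triangular(L, pb, lower=True)    # forward substitution
    return solve_triangular(U, y, lower=False) # back substitution

A = np.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])
b = np.array([11.0, 35.0, 28.0])
assert np.allclose(A @ solve_via_lu(A, b), b)
```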
diff --git a/mlx/linalg.h b/mlx/linalg.h
index 4ea81bef0..9fe4dbf60 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -74,6 +74,18 @@ array pinv(const array& a, StreamOrDevice s = {});

 array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

+std::vector<array> lu(const array& a, StreamOrDevice s = {});
+
+std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});
+
+array solve(const array& a, const array& b, StreamOrDevice s = {});
+
+array solve_triangular(
+    const array& a,
+    const array& b,
+    bool upper = false,
+    StreamOrDevice s = {});
+
 /**
  * Compute the cross product of two arrays along the given axis.
  */
diff --git a/mlx/primitives.h b/mlx/primitives.h
index fed5b4988..19f058fd3 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -2329,7 +2329,6 @@ class Eigh : public Primitive {
       : Primitive(stream),
         uplo_(std::move(uplo)),
         compute_eigenvectors_(compute_eigenvectors) {}
-
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
@@ -2350,4 +2349,16 @@ class Eigh : public Primitive {
   bool compute_eigenvectors_;
 };

+/* LU Factorization primitive. */
+class LUF : public Primitive {
+ public:
+  explicit LUF(Stream stream) : Primitive(stream) {}
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(LUF)
+};
+
 } // namespace mlx::core
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 2fd35d165..a43cebbe7 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -14,13 +14,6 @@ namespace mx = mlx::core;
 namespace nb = nanobind;
 using namespace nb::literals;

-namespace {
-nb::tuple svd_helper(const mx::array& a, mx::StreamOrDevice s /* = {} */) {
-  const auto result = mx::linalg::svd(a, s);
-  return nb::make_tuple(result.at(0), result.at(1), result.at(2));
-}
-} // namespace
-
 void init_linalg(nb::module_& parent_module) {
   auto m = parent_module.def_submodule(
       "linalg", "mlx.core.linalg: linear algebra routines.");
@@ -213,7 +206,10 @@ void init_linalg(nb::module_& parent_module) {
       )pbdoc");
   m.def(
       "svd",
-      &svd_helper,
+      [](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
+        const auto result = mx::linalg::svd(a, s);
+        return nb::make_tuple(result.at(0), result.at(1), result.at(2));
+      },
       "a"_a,
       nb::kw_only(),
       "stream"_a = nb::none(),
@@ -262,7 +258,7 @@ void init_linalg(nb::module_& parent_module) {
       "tri_inv",
       &mx::linalg::tri_inv,
       "a"_a,
-      "upper"_a,
+      "upper"_a = false,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
@@ -276,7 +272,7 @@ void init_linalg(nb::module_& parent_module) {

       Args:
           a (array): Input array.
-          upper (array): Whether the array is upper or lower triangular. Defaults to ``False``.
+          upper (bool, optional): Whether the array is upper or lower triangular. Defaults to ``False``.
           stream (Stream, optional): Stream or device. Defaults to ``None``
             in which case the default stream of the default device is used.

@@ -441,7 +437,6 @@ void init_linalg(nb::module_& parent_module) {
   m.def(
       "eigh",
       [](const mx::array& a, const std::string UPLO, mx::StreamOrDevice s) {
-        // TODO avoid cast?
         auto result = mx::linalg::eigh(a, UPLO, s);
         return nb::make_tuple(result.first, result.second);
       },
@@ -484,4 +479,102 @@ void init_linalg(nb::module_& parent_module) {
           array([[ 0.707107, -0.707107],
                  [ 0.707107,  0.707107]], dtype=float32)
       )pbdoc");
+  m.def(
+      "lu",
+      [](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
+        auto result = mx::linalg::lu(a, s);
+        return nb::make_tuple(result.at(0), result.at(1), result.at(2));
+      },
+      "a"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def lu(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
+      R"pbdoc(
+        Compute the LU factorization of the given matrix ``A``.
+
+        Note, unlike the default behavior of ``scipy.linalg.lu``, the pivots
+        are indices. To reconstruct the input use ``L[P, :] @ U`` for 2
+        dimensions or ``mx.take_along_axis(L, P[..., None], axis=-2) @ U``
+        for more than 2 dimensions.
+
+        To construct the full permutation matrix do:
+
+        .. code-block:: python
+
+          P_mat = mx.put_along_axis(mx.zeros_like(L), P[..., None], mx.array(1.0), axis=-1)
+
+        Args:
+            a (array): Input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            tuple(array, array, array):
+              The ``P``, ``L``, and ``U`` arrays, such that ``A = L[P, :] @ U``.
+      )pbdoc");
+  m.def(
+      "lu_factor",
+      &mx::linalg::lu_factor,
+      "a"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def lu_factor(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
+      R"pbdoc(
+        Computes a compact representation of the LU factorization.
+
+        Args:
+            a (array): Input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            tuple(array, array): The ``LU`` matrix and ``pivots`` array.
+      )pbdoc");
+  m.def(
+      "solve",
+      &mx::linalg::solve,
+      "a"_a,
+      "b"_a,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def solve(a: array, b: array, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Compute the solution to a system of linear equations ``AX = B``.
+
+        Args:
+            a (array): Input array.
+            b (array): Input array.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The unique solution to the system ``AX = B``.
+      )pbdoc");
+  m.def(
+      "solve_triangular",
+      &mx::linalg::solve_triangular,
+      "a"_a,
+      "b"_a,
+      nb::kw_only(),
+      "upper"_a = false,
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def solve_triangular(a: array, b: array, *, upper: bool = False, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Computes the solution of a triangular system of linear equations ``AX = B``.
+
+        Args:
+            a (array): Input array.
+            b (array): Input array.
+            upper (bool, optional): Whether the array is upper or lower
+              triangular. Default: ``False``.
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The unique solution to the system ``AX = B``.
+      )pbdoc");
 }
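A quick usage sketch for the new bindings, assembled from the docstrings and the tests in this diff (the LU-based ops are CPU-only at this point, hence `stream=mx.cpu`):

```python
import mlx.core as mx

a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]])

# Full factorization: pivot indices plus explicit L and U factors.
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
print(mx.allclose(L[P, :] @ U, a))  # True

# Compact factorization: L and U packed into one matrix, plus raw
# LAPACK-style pivots (a sequence of row swaps, not a permutation).
LU, pivots = mx.linalg.lu_factor(a, stream=mx.cpu)
```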
+ )pbdoc"); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index c4ac69178..67e8d7bf9 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -330,6 +330,123 @@ class TestLinalg(mlx_tests.MLXTestCase): mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) ) # Non-square matrix + def test_lu(self): + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array(0.0), stream=mx.cpu) + + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array([0.0, 1.0]), stream=mx.cpu) + + with self.assertRaises(ValueError): + mx.linalg.lu(mx.array([[0, 1], [1, 0]]), stream=mx.cpu) + + # Test 3x3 matrix + a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + self.assertTrue(mx.allclose(L[P, :] @ U, a)) + + # Test batch dimension + a = mx.broadcast_to(a, (5, 5, 3, 3)) + P, L, U = mx.linalg.lu(a, stream=mx.cpu) + L = mx.take_along_axis(L, P[..., None], axis=-2) + self.assertTrue(mx.allclose(L @ U, a)) + + def test_lu_factor(self): + mx.random.seed(7) + + # Test 3x3 matrix + a = mx.random.uniform(shape=(5, 5)) + LU, pivots = mx.linalg.lu_factor(a, stream=mx.cpu) + n = a.shape[-1] + + pivots = pivots.tolist() + perm = list(range(n)) + for i in range(len(pivots)): + perm[i], perm[pivots[i]] = perm[pivots[i]], perm[i] + + L = mx.add(mx.tril(LU, k=-1), mx.eye(n)) + U = mx.triu(LU) + self.assertTrue(mx.allclose(L @ U, a[perm, :])) + + def test_solve(self): + mx.random.seed(7) + + # Test 3x3 matrix with 1D rhs + a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0], [9.0, 2.0, 5.0]]) + b = mx.array([11.0, 35.0, 28.0]) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test symmetric positive-definite matrix + N = 5 + a = mx.random.uniform(shape=(N, N)) + a = mx.matmul(a, a.T) + N * mx.eye(N) + b = mx.random.uniform(shape=(N, 1)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test batch dimension + a = mx.random.uniform(shape=(5, 5, 4, 4)) + b = mx.random.uniform(shape=(5, 5, 4, 1)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-5)) + + # Test large matrix + N = 1000 + a = mx.random.uniform(shape=(N, N)) + b = mx.random.uniform(shape=(N, 1)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, atol=1e-3)) + + # Test multi-column rhs + a = mx.random.uniform(shape=(5, 5)) + b = mx.random.uniform(shape=(5, 8)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test batched multi-column rhs + a = mx.broadcast_to(a, (3, 2, 5, 5)) + b = mx.broadcast_to(b, (3, 1, 5, 8)) + + result = mx.linalg.solve(a, b, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected, rtol=1e-5, atol=1e-5)) + + def test_solve_triangular(self): + # Test lower triangular matrix + a = mx.array([[4.0, 0.0, 0.0], [2.0, 3.0, 0.0], [1.0, -2.0, 5.0]]) + b = mx.array([8.0, 14.0, 3.0]) + + result = mx.linalg.solve_triangular(a, b, upper=False, stream=mx.cpu) + expected = np.linalg.solve(a, b) + self.assertTrue(np.allclose(result, expected)) + + # Test upper triangular matrix + a = mx.array([[3.0, 2.0, 1.0], [0.0, 5.0, 4.0], [0.0, 0.0, 6.0]]) + b = mx.array([13.0, 33.0, 18.0]) + + result = 
+        result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
+        expected = np.linalg.solve(a, b)
+        self.assertTrue(np.allclose(result, expected))
+
+        # Test batch multi-column rhs
+        a = mx.broadcast_to(a, (3, 4, 3, 3))
+        b = mx.broadcast_to(mx.expand_dims(b, -1), (3, 4, 3, 8))
+
+        result = mx.linalg.solve_triangular(a, b, upper=True, stream=mx.cpu)
+        expected = np.linalg.solve(a, b)
+        self.assertTrue(np.allclose(result, expected))
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index 352d16ff2..b2465c29a 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -465,3 +465,95 @@ TEST_CASE("test matrix eigh") {
   // Verify eigendecomposition
   CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
 }
+
+TEST_CASE("test lu") {
+  // Test 2x2 matrix
+  array a = array({1., 2., 3., 4.}, {2, 2});
+  auto out = linalg::lu(a, Device::cpu);
+  auto L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
+  array expected = matmul(L, out[2]);
+  CHECK(allclose(a, expected).item<bool>());
+
+  // Test 3x3 matrix
+  a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
+  out = linalg::lu(a, Device::cpu);
+  L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
+  expected = matmul(L, out[2]);
+  CHECK(allclose(a, expected).item<bool>());
+
+  // Test batch dimension
+  a = broadcast_to(a, {3, 3, 3});
+  out = linalg::lu(a, Device::cpu);
+  L = take_along_axis(out[1], expand_dims(out[0], -1), -2);
+  expected = matmul(L, out[2]);
+  CHECK(allclose(a, expected).item<bool>());
+}
+
+TEST_CASE("test solve") {
+  // 0D and 1D throw
+  CHECK_THROWS(linalg::solve(array(0.), array(0.), Device::cpu));
+  CHECK_THROWS(linalg::solve(array({0.}), array({0.}), Device::cpu));
+
+  // Unsupported types throw
+  CHECK_THROWS(
+      linalg::solve(array({0, 1, 1, 2}, {2, 2}), array({1, 3}), Device::cpu));
+
+  // Non-square throws
+  array a = reshape(arange(6), {3, 2});
+  array b = reshape(arange(3), {3, 1});
+  CHECK_THROWS(linalg::solve(a, b, Device::cpu));
+
+  // Test 2x2 matrix with 1D rhs
+  a = array({2., 1., 1., 3.}, {2, 2});
+  b = array({8., 13.}, {2});
+
+  array result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+
+  // Test 3x3 matrix
+  a = array({1., 2., 3., 4., 5., 6., 7., 8., 10.}, {3, 3});
+  b = array({6., 15., 25.}, {3, 1});
+
+  result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+
+  // Test batch dimension
+  a = broadcast_to(a, {5, 3, 3});
+  b = broadcast_to(b, {5, 3, 1});
+
+  result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+
+  // Test multi-column rhs
+  a = array({2., 1., 1., 1., 3., 2., 1., 0., 0.}, {3, 3});
+  b = array({4., 2., 5., 3., 6., 1.}, {3, 2});
+
+  result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+
+  // Test batch multi-column rhs
+  a = broadcast_to(a, {5, 3, 3});
+  b = broadcast_to(b, {5, 3, 2});
+
+  result = linalg::solve(a, b, Device::cpu);
+  CHECK(allclose(matmul(a, result), b).item<bool>());
+}
+
+TEST_CASE("test solve_triangular") {
+  // Test lower triangular matrix
+  array a = array({2., 0., 0., 3., 1., 0., 1., -1., 1.}, {3, 3});
+  array b = array({2., 5., 0.});
+
+  array result =
+      linalg::solve_triangular(a, b, /* upper = */ false, Device::cpu);
+  array expected = array({1., 2., 1.});
+  CHECK(allclose(expected, result).item<bool>());
+
+  // Test upper triangular matrix
+  a = array({2., 1., 3., 0., 4., 2., 0., 0., 1.}, {3, 3});
+  b = array({5., 14., 3.});
+
+  result = linalg::solve_triangular(a, b, /* upper = */ true, Device::cpu);
+  expected = array({-3., 2., 3.});
+  CHECK(allclose(expected, result).item<bool>());
+}
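For completeness, a matching usage sketch for the two solvers added above, mirroring the values in the tests (again CPU-only in this diff):

```python
import mlx.core as mx

# General solve: find x with a @ x == b via LU factorization.
a = mx.array([[2.0, 1.0], [1.0, 3.0]])
b = mx.array([8.0, 13.0])
x = mx.linalg.solve(a, b, stream=mx.cpu)
print(mx.allclose(a @ x, b))  # True

# Triangular solve: upper=False (the default) means a is lower triangular.
lower = mx.array([[4.0, 0.0], [2.0, 3.0]])
y = mx.linalg.solve_triangular(
    lower, mx.array([8.0, 14.0]), upper=False, stream=mx.cpu
)
```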