diff --git a/mlx/backend/cpu/luf.cpp b/mlx/backend/cpu/luf.cpp
index 4dbf313db..e055f4cac 100644
--- a/mlx/backend/cpu/luf.cpp
+++ b/mlx/backend/cpu/luf.cpp
@@ -59,10 +59,14 @@ void lu_factor_impl(
     }
 
     // Subtract 1 to get 0-based index
-    for (int j = 0; j < pivots.shape(-1); ++j) {
+    int j = 0;
+    for (; j < pivots.shape(-1); ++j) {
       pivots_ptr[j]--;
       row_indices_ptr[j] = j;
     }
+    for (; j < row_indices.shape(-1); ++j) {
+      row_indices_ptr[j] = j;
+    }
     for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
       auto piv = pivots_ptr[j];
       auto t1 = row_indices_ptr[piv];
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index e9a0d6e5a..01aa9b7ff 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -539,10 +539,6 @@ void validate_lu(
         << a.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
-
-  if (a.shape(-1) != a.shape(-2)) {
-    throw std::invalid_argument(fname + " Only defined for square matrices.");
-  }
 }
 
 std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
@@ -552,8 +548,10 @@ std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
   Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
   pivots_shape.push_back(std::min(m, n));
 
+  Shape row_idx_shape(a.shape().begin(), a.shape().end() - 1);
+
   return array::make_arrays(
-      {a.shape(), pivots_shape, pivots_shape},
+      {a.shape(), pivots_shape, row_idx_shape},
       {a.dtype(), uint32, uint32},
       std::make_shared<LUF>(to_stream(s)),
       {astype(a, a.dtype(), s)});
@@ -565,10 +563,24 @@ std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
   auto out = lu_helper(a, s);
   auto& LU = out[0];
   auto& row_pivots = out[2];
-
-  int N = a.shape(-1);
-  auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
+  auto L = tril(LU, /* k = */ -1, s);
   auto U = triu(LU, /* k = */ 0, s);
+
+  int M = a.shape(-2);
+  int N = a.shape(-1);
+  int K = std::min(M, N);
+  if (N != K) {
+    auto start = Shape(L.ndim(), 0);
+    auto stop = L.shape();
+    stop.back() = K;
+    L = slice(L, std::move(start), std::move(stop), s);
+  } else if (M != K) {
+    auto start = Shape(U.ndim(), 0);
+    auto stop = U.shape();
+    stop[U.ndim() - 2] = K;
+    U = slice(U, std::move(start), std::move(stop), s);
+  }
+  L = add(L, eye(M, K, s), s);
   return {row_pivots, L, U};
 }
 
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index bae3dc17a..adc365c62 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -358,6 +358,15 @@ class TestLinalg(mlx_tests.MLXTestCase):
         L = mx.take_along_axis(L, P[..., None], axis=-2)
         self.assertTrue(mx.allclose(L @ U, a))
 
+        # Test non-square matrix
+        a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0]])
+        P, L, U = mx.linalg.lu(a, stream=mx.cpu)
+        self.assertTrue(mx.allclose(L[P, :] @ U, a))
+
+        a = mx.array([[3.0, 1.0], [1.0, 8.0], [9.0, 2.0]])
+        P, L, U = mx.linalg.lu(a, stream=mx.cpu)
+        self.assertTrue(mx.allclose(L[P, :] @ U, a))
+
     def test_lu_factor(self):
         mx.random.seed(7)