mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-24 12:18:20 +08:00
Allow non-square lu (#1889)
This commit is contained in:
@@ -59,10 +59,14 @@ void lu_factor_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Subtract 1 to get 0-based index
|
// Subtract 1 to get 0-based index
|
||||||
for (int j = 0; j < pivots.shape(-1); ++j) {
|
int j = 0;
|
||||||
|
for (; j < pivots.shape(-1); ++j) {
|
||||||
pivots_ptr[j]--;
|
pivots_ptr[j]--;
|
||||||
row_indices_ptr[j] = j;
|
row_indices_ptr[j] = j;
|
||||||
}
|
}
|
||||||
|
for (; j < row_indices.shape(-1); ++j) {
|
||||||
|
row_indices_ptr[j] = j;
|
||||||
|
}
|
||||||
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
|
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
|
||||||
auto piv = pivots_ptr[j];
|
auto piv = pivots_ptr[j];
|
||||||
auto t1 = row_indices_ptr[piv];
|
auto t1 = row_indices_ptr[piv];
|
||||||
|
@@ -539,10 +539,6 @@ void validate_lu(
|
|||||||
<< a.ndim() << " dimensions.";
|
<< a.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (a.shape(-1) != a.shape(-2)) {
|
|
||||||
throw std::invalid_argument(fname + " Only defined for square matrices.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
|
std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
@@ -552,8 +548,10 @@ std::vector<array> lu_helper(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
|
Shape pivots_shape(a.shape().begin(), a.shape().end() - 2);
|
||||||
pivots_shape.push_back(std::min(m, n));
|
pivots_shape.push_back(std::min(m, n));
|
||||||
|
|
||||||
|
Shape row_idx_shape(a.shape().begin(), a.shape().end() - 1);
|
||||||
|
|
||||||
return array::make_arrays(
|
return array::make_arrays(
|
||||||
{a.shape(), pivots_shape, pivots_shape},
|
{a.shape(), pivots_shape, row_idx_shape},
|
||||||
{a.dtype(), uint32, uint32},
|
{a.dtype(), uint32, uint32},
|
||||||
std::make_shared<LUF>(to_stream(s)),
|
std::make_shared<LUF>(to_stream(s)),
|
||||||
{astype(a, a.dtype(), s)});
|
{astype(a, a.dtype(), s)});
|
||||||
@@ -565,10 +563,24 @@ std::vector<array> lu(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
auto out = lu_helper(a, s);
|
auto out = lu_helper(a, s);
|
||||||
auto& LU = out[0];
|
auto& LU = out[0];
|
||||||
auto& row_pivots = out[2];
|
auto& row_pivots = out[2];
|
||||||
|
auto L = tril(LU, /* k = */ -1, s);
|
||||||
int N = a.shape(-1);
|
|
||||||
auto L = add(tril(LU, /* k = */ -1, s), eye(N, s), s);
|
|
||||||
auto U = triu(LU, /* k = */ 0, s);
|
auto U = triu(LU, /* k = */ 0, s);
|
||||||
|
|
||||||
|
int M = a.shape(-2);
|
||||||
|
int N = a.shape(-1);
|
||||||
|
int K = std::min(M, N);
|
||||||
|
if (N != K) {
|
||||||
|
auto start = Shape(L.ndim(), 0);
|
||||||
|
auto stop = L.shape();
|
||||||
|
stop.back() = K;
|
||||||
|
L = slice(L, std::move(start), std::move(stop), s);
|
||||||
|
} else if (M != K) {
|
||||||
|
auto start = Shape(U.ndim(), 0);
|
||||||
|
auto stop = U.shape();
|
||||||
|
stop[U.ndim() - 2] = K;
|
||||||
|
U = slice(U, std::move(start), std::move(stop), s);
|
||||||
|
}
|
||||||
|
L = add(L, eye(M, K, s), s);
|
||||||
return {row_pivots, L, U};
|
return {row_pivots, L, U};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -358,6 +358,15 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
|||||||
L = mx.take_along_axis(L, P[..., None], axis=-2)
|
L = mx.take_along_axis(L, P[..., None], axis=-2)
|
||||||
self.assertTrue(mx.allclose(L @ U, a))
|
self.assertTrue(mx.allclose(L @ U, a))
|
||||||
|
|
||||||
|
# Test non-square matrix
|
||||||
|
a = mx.array([[3.0, 1.0, 2.0], [1.0, 8.0, 6.0]])
|
||||||
|
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
|
||||||
|
self.assertTrue(mx.allclose(L[P, :] @ U, a))
|
||||||
|
|
||||||
|
a = mx.array([[3.0, 1.0], [1.0, 8.0], [9.0, 2.0]])
|
||||||
|
P, L, U = mx.linalg.lu(a, stream=mx.cpu)
|
||||||
|
self.assertTrue(mx.allclose(L[P, :] @ U, a))
|
||||||
|
|
||||||
def test_lu_factor(self):
|
def test_lu_factor(self):
|
||||||
mx.random.seed(7)
|
mx.random.seed(7)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user