Allow non-square lu (#1889)

This commit is contained in:
Awni Hannun
2025-02-20 08:13:23 -08:00
committed by GitHub
parent c86422bdd4
commit bbda0fdbdb
3 changed files with 34 additions and 9 deletions

View File

@@ -59,10 +59,14 @@ void lu_factor_impl(
}
// Subtract 1 to get 0-based index
for (int j = 0; j < pivots.shape(-1); ++j) {
int j = 0;
for (; j < pivots.shape(-1); ++j) {
pivots_ptr[j]--;
row_indices_ptr[j] = j;
}
for (; j < row_indices.shape(-1); ++j) {
row_indices_ptr[j] = j;
}
for (int j = pivots.shape(-1) - 1; j >= 0; --j) {
auto piv = pivots_ptr[j];
auto t1 = row_indices_ptr[piv];