Adds nuclear norm support (#1894)

* adjust norm unit test tolerance
This commit is contained in:
Abe Leininger
2025-03-04 15:26:02 -06:00
committed by GitHub
parent 9680f72cca
commit 3835a428c5
11 changed files with 260 additions and 55 deletions

View File

@@ -4940,7 +4940,8 @@ std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
std::vector<int> new_axes(compute_uv_ ? 3 : 1, ax);
return {linalg::svd(a, compute_uv_, stream()), std::move(new_axes)};
}
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(