mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -4940,7 +4940,8 @@ std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
|
||||
const std::vector<int>& axes) {
|
||||
auto ax = axes[0] >= 0 ? 0 : -1;
|
||||
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
|
||||
return {{linalg::svd(a, stream())}, {ax, ax, ax}};
|
||||
std::vector<int> new_axes(compute_uv_ ? 3 : 1, ax);
|
||||
return {linalg::svd(a, compute_uv_, stream()), std::move(new_axes)};
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
|
||||
|
||||
Reference in New Issue
Block a user