Fix Split::vmap (#1845)

This commit is contained in:
Angelos Katharopoulos
2025-02-08 09:22:13 -08:00
committed by GitHub
parent 1c0c118f7c
commit 9eb7d7362f
2 changed files with 15 additions and 1 deletions

View File

@@ -4273,7 +4273,9 @@ std::pair<std::vector<array>, std::vector<int>> Split::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes};
auto output = split(inputs[0], indices_, axis_ + axis_left, stream());
std::vector<int> output_axes(output.size(), axes[0]);
return {std::move(output), std::move(output_axes)};
}
std::vector<array> Split::vjp(