mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix Split::vmap (#1845)
This commit is contained in:
committed by
GitHub
parent
1c0c118f7c
commit
9eb7d7362f
@@ -4273,7 +4273,9 @@ std::pair<std::vector<array>, std::vector<int>> Split::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
int axis_left = axes[0] >= 0 && axes[0] <= axis_;
|
||||
return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes};
|
||||
auto output = split(inputs[0], indices_, axis_ + axis_left, stream());
|
||||
std::vector<int> output_axes(output.size(), axes[0]);
|
||||
return {std::move(output), std::move(output_axes)};
|
||||
}
|
||||
|
||||
std::vector<array> Split::vjp(
|
||||
|
||||
Reference in New Issue
Block a user