mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 01:50:16 +08:00
Fix concatenate vmap (#1600)
This commit is contained in:

committed by
GitHub

parent
2af7e8a9a6
commit
5e89aace9b
@@ -836,31 +836,43 @@ std::vector<array> Concatenate::jvp(
|
||||
std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
|
||||
const std::vector<array>& inputs,
|
||||
const std::vector<int>& axes) {
|
||||
std::vector<array> t_inputs;
|
||||
int out_ax = -1;
|
||||
int first_vmap = -1;
|
||||
|
||||
// Find the first vmapped input
|
||||
int i = 0;
|
||||
for (; i < axes.size(); i++) {
|
||||
t_inputs.push_back(inputs[i]);
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
if (axes[i] >= 0) {
|
||||
out_ax = axes[i];
|
||||
first_vmap = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (out_ax >= 0) {
|
||||
// Advance to the next input
|
||||
i++;
|
||||
|
||||
// Move vmap axes to the same spot.
|
||||
for (; i < axes.size(); ++i) {
|
||||
if (out_ax != axes[i] && axes[i] >= 0) {
|
||||
// No vmap, should we even be in here?
|
||||
if (out_ax < 0) {
|
||||
return {{concatenate(inputs, axis_, stream())}, {out_ax}};
|
||||
}
|
||||
|
||||
// Make sure vmapped arrays have all vmapped axes in the same location and
|
||||
// expand non-vmapped arrays to be compatible with the vmapped ones.
|
||||
std::vector<array> t_inputs;
|
||||
int N = inputs[first_vmap].shape(out_ax);
|
||||
int axis = axis_ + (axis_ >= out_ax);
|
||||
auto cat_shape = inputs[first_vmap].shape();
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
if (axes[i] >= 0) {
|
||||
if (out_ax != axes[i]) {
|
||||
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
|
||||
} else {
|
||||
t_inputs.push_back(inputs[i]);
|
||||
}
|
||||
} else {
|
||||
cat_shape[axis] = inputs[i].shape(axis_);
|
||||
t_inputs.push_back(broadcast_to(
|
||||
expand_dims(inputs[i], out_ax, stream()), cat_shape, stream()));
|
||||
}
|
||||
}
|
||||
auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
|
||||
|
||||
return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user