Fix concatenate vmap (#1600)

This commit is contained in:
Angelos Katharopoulos
2024-11-19 10:44:04 -08:00
committed by GitHub
parent 2af7e8a9a6
commit 5e89aace9b
2 changed files with 43 additions and 11 deletions

View File

@@ -836,31 +836,43 @@ std::vector<array> Concatenate::jvp(
std::pair<std::vector<array>, std::vector<int>> Concatenate::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
std::vector<array> t_inputs;
int out_ax = -1;
int first_vmap = -1;
// Find the first vmapped input
int i = 0;
for (; i < axes.size(); i++) {
t_inputs.push_back(inputs[i]);
for (int i = 0; i < axes.size(); i++) {
if (axes[i] >= 0) {
out_ax = axes[i];
first_vmap = i;
break;
}
}
if (out_ax >= 0) {
// Advance to the next input
i++;
// Move vmap axes to the same spot.
for (; i < axes.size(); ++i) {
if (out_ax != axes[i] && axes[i] >= 0) {
// No vmap, should we even be in here?
if (out_ax < 0) {
return {{concatenate(inputs, axis_, stream())}, {out_ax}};
}
// Make sure vmapped arrays have all vmapped axes in the same location and
// expand non-vmapped arrays to be compatible with the vmapped ones.
std::vector<array> t_inputs;
int N = inputs[first_vmap].shape(out_ax);
int axis = axis_ + (axis_ >= out_ax);
auto cat_shape = inputs[first_vmap].shape();
for (int i = 0; i < axes.size(); i++) {
if (axes[i] >= 0) {
if (out_ax != axes[i]) {
t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream()));
} else {
t_inputs.push_back(inputs[i]);
}
} else {
cat_shape[axis] = inputs[i].shape(axis_);
t_inputs.push_back(broadcast_to(
expand_dims(inputs[i], out_ax, stream()), cat_shape, stream()));
}
}
auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax);
return {{concatenate(t_inputs, axis, stream())}, {out_ax}};
}