Fix vmap constant output size (#1524)

* use inputs to determine output size

* remove noop vmap tests
This commit is contained in:
Alex Barron
2024-10-30 16:16:53 -07:00
committed by GitHub
parent 917252a5a1
commit 048fabdabd
3 changed files with 38 additions and 36 deletions

View File

@@ -686,6 +686,17 @@ std::vector<array> vmap_replace(
throw std::invalid_argument(msg.str());
}
int vmap_size = -1;
for (int i = 0; i < inputs.size(); ++i) {
if (in_axes[i] >= 0) {
vmap_size = inputs[i].shape(in_axes[i]);
break;
}
}
if (vmap_size == -1) {
throw std::invalid_argument("At least one of in_axes must be non-None.");
}
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
std::unordered_set<std::uintptr_t> needs_vmap;
std::unordered_set<std::uintptr_t> cache;
@@ -782,7 +793,11 @@ std::vector<array> vmap_replace(
}
outputs.push_back(out);
} else {
outputs.push_back(s_outputs[i]);
// When the output has no input dependencies
// use the size of the vmapped axis in the inputs to expand the output
array output = expand_dims(s_outputs[i], out_axes[i]);
output = repeat(output, vmap_size, out_axes[i]);
outputs.push_back(output);
}
}
return outputs;