Add vmap to scatter (#1200)

* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nicolov
2024-08-06 05:12:27 +02:00
committed by GitHub
parent 58d0e199e1
commit 8c9f0278b9
6 changed files with 269 additions and 8 deletions

View File

@@ -563,8 +563,10 @@ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
detail::InTracing in_tracing;
if (in_axes.size() != inputs.size()) {
throw std::invalid_argument(
"[vmap] The number of in axes must match the number of inputs.");
std::stringstream ss;
ss << "[vmap] The number of in axes (" << in_axes.size()
<< ") must match the number of inputs (" << inputs.size() << ").";
throw std::invalid_argument(ss.str());
}
// Some error checking and get the vmap axis size
@@ -620,8 +622,10 @@ std::vector<array> vmap_replace(
const std::vector<int>& in_axes,
const std::vector<int>& out_axes) {
if (out_axes.size() != s_outputs.size()) {
throw std::invalid_argument(
"[vmap] The number of out axes must match the number of outputs.");
std::stringstream msg;
msg << "[vmap] The number of out axes (" << out_axes.size()
<< ") must match the number of outputs (" << s_outputs.size() << ").";
throw std::invalid_argument(msg.str());
}
std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;