Add vmap to scatter (#1200)

* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nicolov
2024-08-06 05:12:27 +02:00
committed by GitHub
parent 58d0e199e1
commit 8c9f0278b9
6 changed files with 269 additions and 8 deletions

View File

@@ -1103,6 +1103,9 @@ array moveaxis(
};
source = check_ax(source);
destination = check_ax(destination);
if (source == destination) {
return a;
}
std::vector<int> reorder(a.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
reorder.erase(reorder.begin() + source);
@@ -2715,9 +2718,8 @@ array scatter(
if (updates.ndim() != (a.ndim() + idx_shape.size())) {
std::ostringstream msg;
msg << "[scatter] Updates with " << updates.ndim()
<< " dimensions does not match the sum of the array and indices "
"dimensions "
<< a.ndim() + idx_shape.size() << ".";
<< " dimensions does not match the sum of the array (" << a.ndim()
<< ") and indices (" << idx_shape.size() << ") dimensions.";
throw std::invalid_argument(msg.str());
}
for (int i = 0; i < idx_shape.size(); ++i) {
@@ -2759,11 +2761,12 @@ array scatter(
inputs.insert(inputs.begin(), a);
// TODO promote or cast?
inputs.push_back(astype(updates, a.dtype(), s));
return array(
a.shape(),
a.dtype(),
std::make_shared<Scatter>(to_stream(s), mode, axes),
inputs);
std::move(inputs));
}
array scatter(