Add vmap to scatter (#1200)

* Add vmap to scatter

* updates

* vmap updates + a few more tests

* bug fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nicolov
2024-08-06 05:12:27 +02:00
committed by GitHub
parent 58d0e199e1
commit 8c9f0278b9
6 changed files with 269 additions and 8 deletions

View File

@@ -1713,7 +1713,9 @@ class Scatter : public UnaryPrimitive {
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP();
DEFINE_GRADS();
void print(std::ostream& os) override {
os << "Scatter";
switch (reduce_type_) {