Scatter vjp (#394)

* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
This commit is contained in:
Angelos Katharopoulos
2024-01-09 13:36:51 -08:00
committed by GitHub
parent e9ca65c939
commit 961435a243
7 changed files with 360 additions and 33 deletions

View File

@@ -3310,7 +3310,6 @@ void init_ops(py::module_& m) {
Returns:
result (array): The tensor dot product.
)pbdoc");
m.def(
"inner",
&inner,
@@ -3331,7 +3330,6 @@ void init_ops(py::module_& m) {
Returns:
result (array): The inner product.
)pbdoc");
m.def(
"outer",
&outer,