Scatter vjp (#394)

* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
This commit is contained in:
Angelos Katharopoulos
2024-01-09 13:36:51 -08:00
committed by GitHub
parent e9ca65c939
commit 961435a243
7 changed files with 360 additions and 33 deletions

View File

@@ -12,3 +12,27 @@ using namespace mlx::core;
array mlx_get_item(const array& src, const py::object& obj);
void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v);
array mlx_add_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_subtract_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_multiply_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_divide_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_maximum_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);
array mlx_minimum_item(
const array& src,
const py::object& obj,
const ScalarOrArray& v);