Scatter vjp (#394)

* Add a first scatter vjp
* Implement the scatter_add vjp
* Add array.at to implement user friendly scatters
This commit is contained in:
Angelos Katharopoulos
2024-01-09 13:36:51 -08:00
committed by GitHub
parent e9ca65c939
commit 961435a243
7 changed files with 360 additions and 33 deletions

View File

@@ -462,6 +462,37 @@ array create_array(array_init_type v, std::optional<Dtype> t) {
}
}
class ArrayAt {
public:
ArrayAt(array x) : x_(std::move(x)) {}
ArrayAt& set_indices(py::object indices) {
indices_ = indices;
return *this;
}
array add(const ScalarOrArray& v) {
return mlx_add_item(x_, indices_, v);
}
array subtract(const ScalarOrArray& v) {
return mlx_subtract_item(x_, indices_, v);
}
array multiply(const ScalarOrArray& v) {
return mlx_multiply_item(x_, indices_, v);
}
array divide(const ScalarOrArray& v) {
return mlx_divide_item(x_, indices_, v);
}
array maximum(const ScalarOrArray& v) {
return mlx_maximum_item(x_, indices_, v);
}
array minimum(const ScalarOrArray& v) {
return mlx_minimum_item(x_, indices_, v);
}
private:
array x_;
py::object indices_;
};
void init_array(py::module_& m) {
// Types
py::class_<Dtype>(
@@ -501,6 +532,26 @@ void init_array(py::module_& m) {
m.attr("bfloat16") = py::cast(bfloat16);
m.attr("complex64") = py::cast(complex64);
// Expose the ArrayAt helper to Python under the name `_ArrayAt`. It can be
// constructed from an array, indexed to choose where an update applies, and
// then asked to perform one of the update operations bound below.
py::class_<ArrayAt>(
m,
"_ArrayAt",
R"pbdoc(
A helper object to apply updates at specific indices.
)pbdoc")
.def(
py::init([](const array& x) { return ArrayAt(x); }),
"x"_a,
R"pbdoc(
__init__(self, x: array)
)pbdoc")
// Indexing is bound to ArrayAt::set_indices, which stores the indices on the
// helper and returns the helper itself so an update call can be chained,
// e.g. `x.at[idx].add(y)`.
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
// Update operations: each takes a single `value` argument and is bound to
// the ArrayAt member function of the same name.
.def("add", &ArrayAt::add, "value"_a)
.def("subtract", &ArrayAt::subtract, "value"_a)
.def("multiply", &ArrayAt::multiply, "value"_a)
.def("divide", &ArrayAt::divide, "value"_a)
.def("maximum", &ArrayAt::maximum, "value"_a)
.def("minimum", &ArrayAt::minimum, "value"_a);
auto array_class = py::class_<array>(
m,
"array",
@@ -610,6 +661,38 @@ void init_array(py::module_& m) {
)pbdoc")
.def("__getitem__", mlx_get_item)
.def("__setitem__", mlx_set_item)
.def_property_readonly(
"at",
[](const array& a) { return ArrayAt(a); },
R"pbdoc(
Used to apply updates at the given indices.
.. note::
Python in place updates for all array frameworks map to
assignment. For instance ``x[idx] += y`` maps to ``x[idx] =
x[idx] + y``. As a result, assigning to the same index ignores
all but one updates. Using ``x.at[idx].add(y)`` will correctly
apply all the updates to all indices.
.. list-table::
:header-rows: 1
* - array.at syntax
- In-place syntax
* - ``x = x.at[idx].add(y)``
- ``x[idx] += y``
* - ``x = x.at[idx].subtract(y)``
- ``x[idx] -= y``
* - ``x = x.at[idx].multiply(y)``
- ``x[idx] *= y``
* - ``x = x.at[idx].divide(y)``
- ``x[idx] /= y``
* - ``x = x.at[idx].maximum(y)``
- ``x[idx] = mx.maximum(x[idx], y)``
* - ``x = x.at[idx].minimum(y)``
- ``x[idx] = mx.minimum(x[idx], y)``
)pbdoc")
.def(
"__len__",
[](const array& a) {