mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Scatter vjp (#394)
* Add a first scatter vjp * Implement the scatter_add vjp * Add array.at to implement user friendly scatters
This commit is contained in:
committed by
GitHub
parent
e9ca65c939
commit
961435a243
@@ -462,6 +462,37 @@ array create_array(array_init_type v, std::optional<Dtype> t) {
|
||||
}
|
||||
}
|
||||
|
||||
class ArrayAt {
|
||||
public:
|
||||
ArrayAt(array x) : x_(std::move(x)) {}
|
||||
ArrayAt& set_indices(py::object indices) {
|
||||
indices_ = indices;
|
||||
return *this;
|
||||
}
|
||||
array add(const ScalarOrArray& v) {
|
||||
return mlx_add_item(x_, indices_, v);
|
||||
}
|
||||
array subtract(const ScalarOrArray& v) {
|
||||
return mlx_subtract_item(x_, indices_, v);
|
||||
}
|
||||
array multiply(const ScalarOrArray& v) {
|
||||
return mlx_multiply_item(x_, indices_, v);
|
||||
}
|
||||
array divide(const ScalarOrArray& v) {
|
||||
return mlx_divide_item(x_, indices_, v);
|
||||
}
|
||||
array maximum(const ScalarOrArray& v) {
|
||||
return mlx_maximum_item(x_, indices_, v);
|
||||
}
|
||||
array minimum(const ScalarOrArray& v) {
|
||||
return mlx_minimum_item(x_, indices_, v);
|
||||
}
|
||||
|
||||
private:
|
||||
array x_;
|
||||
py::object indices_;
|
||||
};
|
||||
|
||||
void init_array(py::module_& m) {
|
||||
// Types
|
||||
py::class_<Dtype>(
|
||||
@@ -501,6 +532,26 @@ void init_array(py::module_& m) {
|
||||
m.attr("bfloat16") = py::cast(bfloat16);
|
||||
m.attr("complex64") = py::cast(complex64);
|
||||
|
||||
py::class_<ArrayAt>(
|
||||
m,
|
||||
"_ArrayAt",
|
||||
R"pbdoc(
|
||||
A helper object to apply updates at specific indices.
|
||||
)pbdoc")
|
||||
.def(
|
||||
py::init([](const array& x) { return ArrayAt(x); }),
|
||||
"x"_a,
|
||||
R"pbdoc(
|
||||
__init__(self, x: array)
|
||||
)pbdoc")
|
||||
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
|
||||
.def("add", &ArrayAt::add, "value"_a)
|
||||
.def("subtract", &ArrayAt::subtract, "value"_a)
|
||||
.def("multiply", &ArrayAt::multiply, "value"_a)
|
||||
.def("divide", &ArrayAt::divide, "value"_a)
|
||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||
|
||||
auto array_class = py::class_<array>(
|
||||
m,
|
||||
"array",
|
||||
@@ -610,6 +661,38 @@ void init_array(py::module_& m) {
|
||||
)pbdoc")
|
||||
.def("__getitem__", mlx_get_item)
|
||||
.def("__setitem__", mlx_set_item)
|
||||
.def_property_readonly(
|
||||
"at",
|
||||
[](const array& a) { return ArrayAt(a); },
|
||||
R"pbdoc(
|
||||
Used to apply updates at the given indices.
|
||||
|
||||
.. note::
|
||||
|
||||
Python in place updates for all array frameworks map to
|
||||
assignment. For instance ``x[idx] += y`` maps to ``x[idx] =
|
||||
x[idx] + y``. As a result, assigning to the same index ignores
|
||||
all but one updates. Using ``x.at[idx].add(y)`` will correctly
|
||||
apply all the updates to all indices.
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
|
||||
* - array.at syntax
|
||||
- In-place syntax
|
||||
* - ``x = x.at[idx].add(y)``
|
||||
- ``x[idx] += y``
|
||||
* - ``x = x.at[idx].subtract(y)``
|
||||
- ``x[idx] -= y``
|
||||
* - ``x = x.at[idx].multiply(y)``
|
||||
- ``x[idx] *= y``
|
||||
* - ``x = x.at[idx].divide(y)``
|
||||
- ``x[idx] /= y``
|
||||
* - ``x = x.at[idx].maximum(y)``
|
||||
- ``x[idx] = mx.maximum(x[idx], y)``
|
||||
* - ``x = x.at[idx].minimum(y)``
|
||||
- ``x[idx] = mx.minimum(x[idx], y)``
|
||||
)pbdoc")
|
||||
.def(
|
||||
"__len__",
|
||||
[](const array& a) {
|
||||
|
||||
Reference in New Issue
Block a user