mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Fix stub generation, change graph exporting for arrows to go to outputs (#455)
This commit is contained in:
@@ -532,25 +532,12 @@ void init_array(py::module_& m) {
|
||||
m.attr("bfloat16") = py::cast(bfloat16);
|
||||
m.attr("complex64") = py::cast(complex64);
|
||||
|
||||
py::class_<ArrayAt>(
|
||||
auto array_at_class = py::class_<ArrayAt>(
|
||||
m,
|
||||
"_ArrayAt",
|
||||
R"pbdoc(
|
||||
A helper object to apply updates at specific indices.
|
||||
)pbdoc")
|
||||
.def(
|
||||
py::init([](const array& x) { return ArrayAt(x); }),
|
||||
"x"_a,
|
||||
R"pbdoc(
|
||||
__init__(self, x: array)
|
||||
)pbdoc")
|
||||
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
|
||||
.def("add", &ArrayAt::add, "value"_a)
|
||||
.def("subtract", &ArrayAt::subtract, "value"_a)
|
||||
.def("multiply", &ArrayAt::multiply, "value"_a)
|
||||
.def("divide", &ArrayAt::divide, "value"_a)
|
||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||
)pbdoc");
|
||||
|
||||
auto array_class = py::class_<array>(
|
||||
m,
|
||||
@@ -573,6 +560,21 @@ void init_array(py::module_& m) {
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
array_at_class
|
||||
.def(
|
||||
py::init([](const array& x) { return ArrayAt(x); }),
|
||||
"x"_a,
|
||||
R"pbdoc(
|
||||
__init__(self, x: array)
|
||||
)pbdoc")
|
||||
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
|
||||
.def("add", &ArrayAt::add, "value"_a)
|
||||
.def("subtract", &ArrayAt::subtract, "value"_a)
|
||||
.def("multiply", &ArrayAt::multiply, "value"_a)
|
||||
.def("divide", &ArrayAt::divide, "value"_a)
|
||||
.def("maximum", &ArrayAt::maximum, "value"_a)
|
||||
.def("minimum", &ArrayAt::minimum, "value"_a);
|
||||
|
||||
array_class
|
||||
.def_buffer([](array& a) {
|
||||
// Eval if not already evaled
|
||||
@@ -680,17 +682,17 @@ void init_array(py::module_& m) {
|
||||
|
||||
* - array.at syntax
|
||||
- In-place syntax
|
||||
* - ``x = x.at[idx].add(y)``
|
||||
* - ``x = x.at[idx].add(y)``
|
||||
- ``x[idx] += y``
|
||||
* - ``x = x.at[idx].subtract(y)``
|
||||
* - ``x = x.at[idx].subtract(y)``
|
||||
- ``x[idx] -= y``
|
||||
* - ``x = x.at[idx].multiply(y)``
|
||||
* - ``x = x.at[idx].multiply(y)``
|
||||
- ``x[idx] *= y``
|
||||
* - ``x = x.at[idx].divide(y)``
|
||||
* - ``x = x.at[idx].divide(y)``
|
||||
- ``x[idx] /= y``
|
||||
* - ``x = x.at[idx].maximum(y)``
|
||||
* - ``x = x.at[idx].maximum(y)``
|
||||
- ``x[idx] = mx.maximum(x[idx], y)``
|
||||
* - ``x = x.at[idx].minimum(y)``
|
||||
* - ``x = x.at[idx].minimum(y)``
|
||||
- ``x[idx] = mx.minimum(x[idx], y)``
|
||||
)pbdoc")
|
||||
.def(
|
||||
|
||||
Reference in New Issue
Block a user