Fix stub generation, change graph exporting for arrows to go to outputs (#455)

This commit is contained in:
Awni Hannun
2024-01-14 14:06:16 -08:00
committed by GitHub
parent 6e81c3e164
commit 41cc7bdfdb
5 changed files with 49 additions and 47 deletions

View File

@@ -532,25 +532,12 @@ void init_array(py::module_& m) {
m.attr("bfloat16") = py::cast(bfloat16);
m.attr("complex64") = py::cast(complex64);
py::class_<ArrayAt>(
auto array_at_class = py::class_<ArrayAt>(
m,
"_ArrayAt",
R"pbdoc(
A helper object to apply updates at specific indices.
)pbdoc")
.def(
py::init([](const array& x) { return ArrayAt(x); }),
"x"_a,
R"pbdoc(
__init__(self, x: array)
)pbdoc")
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
.def("add", &ArrayAt::add, "value"_a)
.def("subtract", &ArrayAt::subtract, "value"_a)
.def("multiply", &ArrayAt::multiply, "value"_a)
.def("divide", &ArrayAt::divide, "value"_a)
.def("maximum", &ArrayAt::maximum, "value"_a)
.def("minimum", &ArrayAt::minimum, "value"_a);
)pbdoc");
auto array_class = py::class_<array>(
m,
@@ -573,6 +560,21 @@ void init_array(py::module_& m) {
)pbdoc");
}
array_at_class
.def(
py::init([](const array& x) { return ArrayAt(x); }),
"x"_a,
R"pbdoc(
__init__(self, x: array)
)pbdoc")
.def("__getitem__", &ArrayAt::set_indices, "indices"_a)
.def("add", &ArrayAt::add, "value"_a)
.def("subtract", &ArrayAt::subtract, "value"_a)
.def("multiply", &ArrayAt::multiply, "value"_a)
.def("divide", &ArrayAt::divide, "value"_a)
.def("maximum", &ArrayAt::maximum, "value"_a)
.def("minimum", &ArrayAt::minimum, "value"_a);
array_class
.def_buffer([](array& a) {
// Eval if not already evaled
@@ -680,17 +682,17 @@ void init_array(py::module_& m) {
* - array.at syntax
- In-place syntax
* - ``x = x.at[idx].add(y)``
* - ``x = x.at[idx].add(y)``
- ``x[idx] += y``
* - ``x = x.at[idx].subtract(y)``
* - ``x = x.at[idx].subtract(y)``
- ``x[idx] -= y``
* - ``x = x.at[idx].multiply(y)``
* - ``x = x.at[idx].multiply(y)``
- ``x[idx] *= y``
* - ``x = x.at[idx].divide(y)``
* - ``x = x.at[idx].divide(y)``
- ``x[idx] /= y``
* - ``x = x.at[idx].maximum(y)``
* - ``x = x.at[idx].maximum(y)``
- ``x[idx] = mx.maximum(x[idx], y)``
* - ``x = x.at[idx].minimum(y)``
* - ``x = x.at[idx].minimum(y)``
- ``x[idx] = mx.minimum(x[idx], y)``
)pbdoc")
.def(