Fix stub generation, change graph exporting for arrows to go to outputs (#455)

This commit is contained in:
Awni Hannun
2024-01-14 14:06:16 -08:00
committed by GitHub
parent 6e81c3e164
commit 41cc7bdfdb
5 changed files with 49 additions and 47 deletions

View File

@@ -438,6 +438,9 @@ auto py_vmap(
}
void init_transforms(py::module_& m) {
py::options options;
options.disable_function_signatures();
m.def(
"eval",
[](const py::args& args) {
@@ -445,6 +448,8 @@ void init_transforms(py::module_& m) {
eval(arrays);
},
R"pbdoc(
eval(*args) -> None
Evaluate an :class:`array` or tree of :class:`array`.
Args:
@@ -476,6 +481,9 @@ void init_transforms(py::module_& m) {
"primals"_a,
"tangents"_a,
R"pbdoc(
jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]
Compute the Jacobian-vector product.
This computes the product of the Jacobian of a function ``fun`` evaluated
@@ -517,6 +525,8 @@ void init_transforms(py::module_& m) {
"primals"_a,
"cotangents"_a,
R"pbdoc(
vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]
Compute the vector-Jacobian product.
Computes the product of the ``cotangents`` with the Jacobian of a
@@ -549,6 +559,8 @@ void init_transforms(py::module_& m) {
"argnums"_a = std::nullopt,
"argnames"_a = std::vector<std::string>{},
R"pbdoc(
value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
Returns a function which computes the value and gradient of ``fun``.
The function passed to :func:`value_and_grad` should return either
@@ -615,6 +627,8 @@ void init_transforms(py::module_& m) {
"argnums"_a = std::nullopt,
"argnames"_a = std::vector<std::string>{},
R"pbdoc(
grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
Returns a function which computes the gradient of ``fun``.
Args:
@@ -645,6 +659,8 @@ void init_transforms(py::module_& m) {
"in_axes"_a = 0,
"out_axes"_a = 0,
R"pbdoc(
vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function
Returns a vectorized version of ``fun``.
Args:
@@ -670,6 +686,8 @@ void init_transforms(py::module_& m) {
simplify(arrays);
},
R"pbdoc(
simplify(*args) -> None
Simplify the graph that computes the arrays.
Run a few fast graph simplification operations to reuse computation and