mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 01:19:21 +08:00
Fix stub generation, change graph exporting for arrows to go to outputs (#455)
This commit is contained in:
@@ -438,6 +438,9 @@ auto py_vmap(
|
||||
}
|
||||
|
||||
void init_transforms(py::module_& m) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
m.def(
|
||||
"eval",
|
||||
[](const py::args& args) {
|
||||
@@ -445,6 +448,8 @@ void init_transforms(py::module_& m) {
|
||||
eval(arrays);
|
||||
},
|
||||
R"pbdoc(
|
||||
eval(*args) -> None
|
||||
|
||||
Evaluate an :class:`array` or tree of :class:`array`.
|
||||
|
||||
Args:
|
||||
@@ -476,6 +481,9 @@ void init_transforms(py::module_& m) {
|
||||
"primals"_a,
|
||||
"tangents"_a,
|
||||
R"pbdoc(
|
||||
jvp(fun: function, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]
|
||||
|
||||
|
||||
Compute the Jacobian-vector product.
|
||||
|
||||
This computes the product of the Jacobian of a function ``fun`` evaluated
|
||||
@@ -517,6 +525,8 @@ void init_transforms(py::module_& m) {
|
||||
"primals"_a,
|
||||
"cotangents"_a,
|
||||
R"pbdoc(
|
||||
vjp(fun: function, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]
|
||||
|
||||
Compute the vector-Jacobian product.
|
||||
|
||||
Computes the product of the ``cotangents`` with the Jacobian of a
|
||||
@@ -549,6 +559,8 @@ void init_transforms(py::module_& m) {
|
||||
"argnums"_a = std::nullopt,
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
R"pbdoc(
|
||||
value_and_grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
|
||||
|
||||
Returns a function which computes the value and gradient of ``fun``.
|
||||
|
||||
The function passed to :func:`value_and_grad` should return either
|
||||
@@ -615,6 +627,8 @@ void init_transforms(py::module_& m) {
|
||||
"argnums"_a = std::nullopt,
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
R"pbdoc(
|
||||
grad(fun: function, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> function
|
||||
|
||||
Returns a function which computes the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
@@ -645,6 +659,8 @@ void init_transforms(py::module_& m) {
|
||||
"in_axes"_a = 0,
|
||||
"out_axes"_a = 0,
|
||||
R"pbdoc(
|
||||
vmap(fun: function, in_axes: object = 0, out_axes: object = 0) -> function
|
||||
|
||||
Returns a vectorized version of ``fun``.
|
||||
|
||||
Args:
|
||||
@@ -670,6 +686,8 @@ void init_transforms(py::module_& m) {
|
||||
simplify(arrays);
|
||||
},
|
||||
R"pbdoc(
|
||||
simplify(*args) -> None
|
||||
|
||||
Simplify the graph that computes the arrays.
|
||||
|
||||
Run a few fast graph simplification operations to reuse computation and
|
||||
|
||||
Reference in New Issue
Block a user