Some fixes to typing (#1371)

* some fixes to typing

* fix module reference

* comment
This commit is contained in:
Awni Hannun
2024-08-28 11:16:19 -07:00
committed by GitHub
parent bd47e1f066
commit 291cf40aca
15 changed files with 152 additions and 145 deletions

View File

@@ -178,7 +178,7 @@ auto py_value_and_grad(
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but "
<< "scalar array (Union[array, tuple[array, Any, ...]]); but "
<< type_name_str(py_value_out) << " was returned.";
throw std::invalid_argument(msg.str());
}
@@ -197,7 +197,7 @@ auto py_value_and_grad(
msg << error_msg_tag << " The return value of the function "
<< "whose gradient we want to compute should be either a "
<< "scalar array or a tuple with the first value being a "
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
<< "scalar array (Union[array, tuple[array, Any, ...]]); but it "
<< "was a tuple with the first value being of type "
<< type_name_str(ret[0]) << " .";
throw std::invalid_argument(msg.str());
@@ -973,13 +973,13 @@ void init_transforms(nb::module_& m) {
.def(
nb::init<nb::callable>(),
"f"_a,
nb::sig("def __init__(self, f: callable)"))
nb::sig("def __init__(self, f: Callable)"))
.def("__call__", &PyCustomFunction::call_impl)
.def(
"vjp",
&PyCustomFunction::set_vjp,
"f"_a,
nb::sig("def vjp(self, f_vjp: callable)"),
nb::sig("def vjp(self, f: Callable)"),
R"pbdoc(
Define a custom vjp for the wrapped function.
@@ -1001,7 +1001,7 @@ void init_transforms(nb::module_& m) {
"jvp",
&PyCustomFunction::set_jvp,
"f"_a,
nb::sig("def jvp(self, f_jvp: callable)"),
nb::sig("def jvp(self, f: Callable)"),
R"pbdoc(
Define a custom jvp for the wrapped function.
@@ -1021,7 +1021,7 @@ void init_transforms(nb::module_& m) {
"vmap",
&PyCustomFunction::set_vmap,
"f"_a,
nb::sig("def vmap(self, f_vmap: callable)"),
nb::sig("def vmap(self, f: Callable)"),
R"pbdoc(
Define a custom vectorization transformation for the wrapped function.
@@ -1116,7 +1116,7 @@ void init_transforms(nb::module_& m) {
"primals"_a,
"tangents"_a,
nb::sig(
"def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"),
"def jvp(fun: Callable, primals: list[array], tangents: list[array]) -> tuple[list[array], list[array]]"),
R"pbdoc(
Compute the Jacobian-vector product.
@@ -1124,7 +1124,7 @@ void init_transforms(nb::module_& m) {
at ``primals`` with the ``tangents``.
Args:
fun (callable): A function which takes a variable number of :class:`array`
fun (Callable): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian.
@@ -1155,7 +1155,7 @@ void init_transforms(nb::module_& m) {
"primals"_a,
"cotangents"_a,
nb::sig(
"def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"),
"def vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]"),
R"pbdoc(
Compute the vector-Jacobian product.
@@ -1163,7 +1163,7 @@ void init_transforms(nb::module_& m) {
function ``fun`` evaluated at ``primals``.
Args:
fun (callable): A function which takes a variable number of :class:`array`
fun (Callable): A function which takes a variable number of :class:`array`
and returns a single :class:`array` or list of :class:`array`.
primals (list(array)): A list of :class:`array` at which to
evaluate the Jacobian.
@@ -1189,7 +1189,7 @@ void init_transforms(nb::module_& m) {
"argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{},
nb::sig(
"def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
"def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
R"pbdoc(
Returns a function which computes the value and gradient of ``fun``.
@@ -1221,7 +1221,7 @@ void init_transforms(nb::module_& m) {
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
Args:
fun (callable): A function which takes a variable number of
fun (Callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a scalar output :class:`array` or a tuple the first element
of which should be a scalar :class:`array`.
@@ -1235,7 +1235,7 @@ void init_transforms(nb::module_& m) {
no gradients for keyword arguments by default.
Returns:
callable: A function which returns a tuple where the first element
Callable: A function which returns a tuple where the first element
is the output of `fun` and the second element is the gradients w.r.t.
the loss.
)pbdoc");
@@ -1257,12 +1257,12 @@ void init_transforms(nb::module_& m) {
"argnums"_a = nb::none(),
"argnames"_a = std::vector<std::string>{},
nb::sig(
"def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
"def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
R"pbdoc(
Returns a function which computes the gradient of ``fun``.
Args:
fun (callable): A function which takes a variable number of
fun (Callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a scalar output :class:`array`.
argnums (int or list(int), optional): Specify the index (or indices)
@@ -1275,7 +1275,7 @@ void init_transforms(nb::module_& m) {
no gradients for keyword arguments by default.
Returns:
callable: A function which has the same input arguments as ``fun`` and
Callable: A function which has the same input arguments as ``fun`` and
returns the gradient(s).
)pbdoc");
m.def(
@@ -1289,12 +1289,12 @@ void init_transforms(nb::module_& m) {
"in_axes"_a = 0,
"out_axes"_a = 0,
nb::sig(
"def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"),
"def vmap(fun: Callable, in_axes: object = 0, out_axes: object = 0) -> Callable"),
R"pbdoc(
Returns a vectorized version of ``fun``.
Args:
fun (callable): A function which takes a variable number of
fun (Callable): A function which takes a variable number of
:class:`array` or a tree of :class:`array` and returns
a variable number of :class:`array` or a tree of :class:`array`.
in_axes (int, optional): An integer or a valid prefix tree of the
@@ -1307,7 +1307,7 @@ void init_transforms(nb::module_& m) {
Defaults to ``0``.
Returns:
callable: The vectorized function.
Callable: The vectorized function.
)pbdoc");
m.def(
"export_to_dot",
@@ -1367,11 +1367,13 @@ void init_transforms(nb::module_& m) {
"inputs"_a = nb::none(),
"outputs"_a = nb::none(),
"shapeless"_a = false,
nb::sig(
"def compile(fun: Callable, inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable"),
R"pbdoc(
Returns a compiled function which produces the same output as ``fun``.
Args:
fun (callable): A function which takes a variable number of
fun (Callable): A function which takes a variable number of
:class:`array` or trees of :class:`array` and returns
a variable number of :class:`array` or trees of :class:`array`.
inputs (list or dict, optional): These inputs will be captured during
@@ -1392,7 +1394,7 @@ void init_transforms(nb::module_& m) {
``shapeless`` set to ``True``. Default: ``False``
Returns:
callable: A compiled function which has the same input arguments
Callable: A compiled function which has the same input arguments
as ``fun`` and returns the the same output(s).
)pbdoc");
m.def(