mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 17:12:49 +08:00
Some fixes to typing (#1371)
* some fixes to typing * fix module reference * comment
This commit is contained in:
@@ -178,7 +178,7 @@ auto py_value_and_grad(
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but "
|
||||
<< "scalar array (Union[array, tuple[array, Any, ...]]); but "
|
||||
<< type_name_str(py_value_out) << " was returned.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@@ -197,7 +197,7 @@ auto py_value_and_grad(
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
|
||||
<< "scalar array (Union[array, tuple[array, Any, ...]]); but it "
|
||||
<< "was a tuple with the first value being of type "
|
||||
<< type_name_str(ret[0]) << " .";
|
||||
throw std::invalid_argument(msg.str());
|
||||
@@ -973,13 +973,13 @@ void init_transforms(nb::module_& m) {
|
||||
.def(
|
||||
nb::init<nb::callable>(),
|
||||
"f"_a,
|
||||
nb::sig("def __init__(self, f: callable)"))
|
||||
nb::sig("def __init__(self, f: Callable)"))
|
||||
.def("__call__", &PyCustomFunction::call_impl)
|
||||
.def(
|
||||
"vjp",
|
||||
&PyCustomFunction::set_vjp,
|
||||
"f"_a,
|
||||
nb::sig("def vjp(self, f_vjp: callable)"),
|
||||
nb::sig("def vjp(self, f: Callable)"),
|
||||
R"pbdoc(
|
||||
Define a custom vjp for the wrapped function.
|
||||
|
||||
@@ -1001,7 +1001,7 @@ void init_transforms(nb::module_& m) {
|
||||
"jvp",
|
||||
&PyCustomFunction::set_jvp,
|
||||
"f"_a,
|
||||
nb::sig("def jvp(self, f_jvp: callable)"),
|
||||
nb::sig("def jvp(self, f: Callable)"),
|
||||
R"pbdoc(
|
||||
Define a custom jvp for the wrapped function.
|
||||
|
||||
@@ -1021,7 +1021,7 @@ void init_transforms(nb::module_& m) {
|
||||
"vmap",
|
||||
&PyCustomFunction::set_vmap,
|
||||
"f"_a,
|
||||
nb::sig("def vmap(self, f_vmap: callable)"),
|
||||
nb::sig("def vmap(self, f: Callable)"),
|
||||
R"pbdoc(
|
||||
Define a custom vectorization transformation for the wrapped function.
|
||||
|
||||
@@ -1116,7 +1116,7 @@ void init_transforms(nb::module_& m) {
|
||||
"primals"_a,
|
||||
"tangents"_a,
|
||||
nb::sig(
|
||||
"def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"),
|
||||
"def jvp(fun: Callable, primals: list[array], tangents: list[array]) -> tuple[list[array], list[array]]"),
|
||||
R"pbdoc(
|
||||
Compute the Jacobian-vector product.
|
||||
|
||||
@@ -1124,7 +1124,7 @@ void init_transforms(nb::module_& m) {
|
||||
at ``primals`` with the ``tangents``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of :class:`array`
|
||||
fun (Callable): A function which takes a variable number of :class:`array`
|
||||
and returns a single :class:`array` or list of :class:`array`.
|
||||
primals (list(array)): A list of :class:`array` at which to
|
||||
evaluate the Jacobian.
|
||||
@@ -1155,7 +1155,7 @@ void init_transforms(nb::module_& m) {
|
||||
"primals"_a,
|
||||
"cotangents"_a,
|
||||
nb::sig(
|
||||
"def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"),
|
||||
"def vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]"),
|
||||
R"pbdoc(
|
||||
Compute the vector-Jacobian product.
|
||||
|
||||
@@ -1163,7 +1163,7 @@ void init_transforms(nb::module_& m) {
|
||||
function ``fun`` evaluated at ``primals``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of :class:`array`
|
||||
fun (Callable): A function which takes a variable number of :class:`array`
|
||||
and returns a single :class:`array` or list of :class:`array`.
|
||||
primals (list(array)): A list of :class:`array` at which to
|
||||
evaluate the Jacobian.
|
||||
@@ -1189,7 +1189,7 @@ void init_transforms(nb::module_& m) {
|
||||
"argnums"_a = nb::none(),
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
nb::sig(
|
||||
"def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
|
||||
"def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
|
||||
R"pbdoc(
|
||||
Returns a function which computes the value and gradient of ``fun``.
|
||||
|
||||
@@ -1221,7 +1221,7 @@ void init_transforms(nb::module_& m) {
|
||||
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a scalar output :class:`array` or a tuple the first element
|
||||
of which should be a scalar :class:`array`.
|
||||
@@ -1235,7 +1235,7 @@ void init_transforms(nb::module_& m) {
|
||||
no gradients for keyword arguments by default.
|
||||
|
||||
Returns:
|
||||
callable: A function which returns a tuple where the first element
|
||||
Callable: A function which returns a tuple where the first element
|
||||
is the output of `fun` and the second element is the gradients w.r.t.
|
||||
the loss.
|
||||
)pbdoc");
|
||||
@@ -1257,12 +1257,12 @@ void init_transforms(nb::module_& m) {
|
||||
"argnums"_a = nb::none(),
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
nb::sig(
|
||||
"def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
|
||||
"def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
|
||||
R"pbdoc(
|
||||
Returns a function which computes the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a scalar output :class:`array`.
|
||||
argnums (int or list(int), optional): Specify the index (or indices)
|
||||
@@ -1275,7 +1275,7 @@ void init_transforms(nb::module_& m) {
|
||||
no gradients for keyword arguments by default.
|
||||
|
||||
Returns:
|
||||
callable: A function which has the same input arguments as ``fun`` and
|
||||
Callable: A function which has the same input arguments as ``fun`` and
|
||||
returns the gradient(s).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
@@ -1289,12 +1289,12 @@ void init_transforms(nb::module_& m) {
|
||||
"in_axes"_a = 0,
|
||||
"out_axes"_a = 0,
|
||||
nb::sig(
|
||||
"def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"),
|
||||
"def vmap(fun: Callable, in_axes: object = 0, out_axes: object = 0) -> Callable"),
|
||||
R"pbdoc(
|
||||
Returns a vectorized version of ``fun``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or a tree of :class:`array` and returns
|
||||
a variable number of :class:`array` or a tree of :class:`array`.
|
||||
in_axes (int, optional): An integer or a valid prefix tree of the
|
||||
@@ -1307,7 +1307,7 @@ void init_transforms(nb::module_& m) {
|
||||
Defaults to ``0``.
|
||||
|
||||
Returns:
|
||||
callable: The vectorized function.
|
||||
Callable: The vectorized function.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
@@ -1367,11 +1367,13 @@ void init_transforms(nb::module_& m) {
|
||||
"inputs"_a = nb::none(),
|
||||
"outputs"_a = nb::none(),
|
||||
"shapeless"_a = false,
|
||||
nb::sig(
|
||||
"def compile(fun: Callable, inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable"),
|
||||
R"pbdoc(
|
||||
Returns a compiled function which produces the same output as ``fun``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
inputs (list or dict, optional): These inputs will be captured during
|
||||
@@ -1392,7 +1394,7 @@ void init_transforms(nb::module_& m) {
|
||||
``shapeless`` set to ``True``. Default: ``False``
|
||||
|
||||
Returns:
|
||||
callable: A compiled function which has the same input arguments
|
||||
Callable: A compiled function which has the same input arguments
|
||||
as ``fun`` and returns the the same output(s).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
|
||||
Reference in New Issue
Block a user