mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 03:58:12 +08:00
Some fixes to typing (#1371)
* some fixes to typing * fix module reference * comment
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
#include <nanobind/typing.h>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "python/src/buffer.h"
|
||||
@@ -113,6 +114,7 @@ void init_array(nb::module_& m) {
|
||||
.def("__hash__", [](const Dtype& t) {
|
||||
return static_cast<int64_t>(t.val);
|
||||
});
|
||||
|
||||
m.attr("bool_") = nb::cast(bool_);
|
||||
m.attr("uint8") = nb::cast(uint8);
|
||||
m.attr("uint16") = nb::cast(uint16);
|
||||
@@ -177,7 +179,7 @@ void init_array(nb::module_& m) {
|
||||
.export_values();
|
||||
nb::class_<ArrayAt>(
|
||||
m,
|
||||
"_ArrayAt",
|
||||
"ArrayAt",
|
||||
R"pbdoc(
|
||||
A helper object to apply updates at specific indices.
|
||||
)pbdoc")
|
||||
@@ -195,7 +197,7 @@ void init_array(nb::module_& m) {
|
||||
|
||||
nb::class_<ArrayPythonIterator>(
|
||||
m,
|
||||
"_ArrayIterator",
|
||||
"ArrayIterator",
|
||||
R"pbdoc(
|
||||
A helper object to iterate over the 1st dimension of an array.
|
||||
)pbdoc")
|
||||
|
||||
@@ -229,34 +229,35 @@ void init_fast(nb::module_& parent_module) {
|
||||
Returns:
|
||||
Callable ``metal_kernel``.
|
||||
|
||||
.. code-block:: python
|
||||
Example:
|
||||
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = """
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
"""
|
||||
.. code-block:: python
|
||||
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"inp": a},
|
||||
template={"T": mx.float32},
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes={"out": a.shape},
|
||||
output_dtypes={"out": a.dtype},
|
||||
verbose=True,
|
||||
)
|
||||
return outputs["out"]
|
||||
def exp_elementwise(a: mx.array):
|
||||
source = '''
|
||||
uint elem = thread_position_in_grid.x;
|
||||
T tmp = inp[elem];
|
||||
out[elem] = metal::exp(tmp);
|
||||
'''
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
kernel = mx.fast.metal_kernel(
|
||||
name="myexp",
|
||||
source=source
|
||||
)
|
||||
outputs = kernel(
|
||||
inputs={"inp": a},
|
||||
template={"T": mx.float32},
|
||||
grid=(a.size, 1, 1),
|
||||
threadgroup=(256, 1, 1),
|
||||
output_shapes={"out": a.shape},
|
||||
output_dtypes={"out": a.dtype},
|
||||
verbose=True,
|
||||
)
|
||||
return outputs["out"]
|
||||
|
||||
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
|
||||
b = exp_elementwise(a)
|
||||
assert mx.allclose(b, mx.exp(a))
|
||||
)pbdoc")
|
||||
.def(
|
||||
"__call__",
|
||||
|
||||
@@ -63,7 +63,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def norm(a: array, /, ord: Union[None, int, float, str] = None, axis: Union[None, int, list[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Matrix or vector norm.
|
||||
|
||||
@@ -74,7 +74,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
|
||||
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
|
||||
2-norm of ``a.flatten`` will be returned.
|
||||
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
|
||||
ord (int, float or str, optional): Order of the norm (see table under ``Notes``).
|
||||
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
|
||||
along the given ``axis``. Default: ``None``.
|
||||
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
|
||||
@@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array)"),
|
||||
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array)"),
|
||||
R"pbdoc(
|
||||
The QR factorization of the input matrix.
|
||||
|
||||
@@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> (array, array, array)"),
|
||||
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array, array)"),
|
||||
R"pbdoc(
|
||||
The Singular Value Decomposition (SVD) of the input matrix.
|
||||
|
||||
|
||||
@@ -1360,7 +1360,7 @@ void init_ops(nb::module_& m) {
|
||||
"dtype"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def arange(stop : Union[int, float], step : Union[None, int, float], dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
|
||||
"def arange(stop : Union[int, float], step : Union[None, int, float] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"));
|
||||
m.def(
|
||||
"linspace",
|
||||
[](Scalar start,
|
||||
@@ -2695,7 +2695,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def concatenate(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def concatenate(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Concatenate the arrays along the given axis.
|
||||
|
||||
@@ -2723,7 +2723,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def concat(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
See :func:`concatenate`.
|
||||
)pbdoc");
|
||||
@@ -2743,7 +2743,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def stack(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def stack(arrays: list[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Stacks the arrays along a new axis.
|
||||
|
||||
@@ -2770,7 +2770,7 @@ void init_ops(nb::module_& m) {
|
||||
"indexing"_a = "xy",
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def meshgrid(*arrays: array, sparse: Optional[bool] = false, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def meshgrid(*arrays: array, sparse: Optional[bool] = False, indexing: Optional[str] = 'xy', stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate multidimensional coordinate grids from 1-D coordinate arrays
|
||||
|
||||
@@ -2889,7 +2889,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Pad an array with a constant value
|
||||
|
||||
@@ -3291,7 +3291,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv2d(input: array, weight: array, /, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def conv2d(input: array, weight: array, /, stride: Union[int, tuple[int, int]] = 1, padding: Union[int, tuple[int, int]] = 0, dilation: Union[int, tuple[int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
2D convolution over an input with several channels
|
||||
|
||||
@@ -3361,7 +3361,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv3d(input: array, weight: array, /, stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def conv3d(input: array, weight: array, /, stride: Union[int, tuple[int, int, int]] = 1, padding: Union[int, tuple[int, int, int]] = 0, dilation: Union[int, tuple[int, int, int]] = 1, groups: int = 1, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
3D convolution over an input with several channels
|
||||
|
||||
@@ -3460,7 +3460,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = false, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def conv_general(input: array, weight: array, /, stride: Union[int, Sequence[int]] = 1, padding: Union[int, Sequence[int], tuple[Sequence[int], Sequence[int]]] = 0, kernel_dilation: Union[int, Sequence[int]] = 1, input_dilation: Union[int, Sequence[int]] = 1, groups: int = 1, flip: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
General convolution over an input with several channels
|
||||
|
||||
@@ -3560,7 +3560,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, Dict[str, array]]"),
|
||||
"def load(file: str, /, format: Optional[str] = None, return_metadata: bool = False, *, stream: Union[None, Stream, Device] = None) -> Union[array, dict[str, array]]"),
|
||||
R"pbdoc(
|
||||
Load array(s) from a binary file.
|
||||
|
||||
@@ -3594,7 +3594,7 @@ void init_ops(nb::module_& m) {
|
||||
"arrays"_a,
|
||||
"metadata"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def save_safetensors(file: str, arrays: Dict[str, array], metadata: Optional[Dict[str, str]] = None)"),
|
||||
"def save_safetensors(file: str, arrays: dict[str, array], metadata: Optional[dict[str, str]] = None)"),
|
||||
R"pbdoc(
|
||||
Save array(s) to a binary file in ``.safetensors`` format.
|
||||
|
||||
@@ -3615,7 +3615,7 @@ void init_ops(nb::module_& m) {
|
||||
"arrays"_a,
|
||||
"metadata"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def save_gguf(file: str, arrays: Dict[str, array], metadata: Dict[str, Union[array, str, List[str]]])"),
|
||||
"def save_gguf(file: str, arrays: dict[str, array], metadata: dict[str, Union[array, str, list[str]]])"),
|
||||
R"pbdoc(
|
||||
Save array(s) to a binary file in ``.gguf`` format.
|
||||
|
||||
@@ -3769,7 +3769,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
|
||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
|
||||
R"pbdoc(
|
||||
Quantize the matrix ``w`` using ``bits`` bits per element.
|
||||
|
||||
@@ -3924,7 +3924,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def tensordot(a: array, b: array, /, axes: Union[int, List[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def tensordot(a: array, b: array, /, axes: Union[int, list[Sequence[int]]] = 2, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Compute the tensor dot product along the specified axes.
|
||||
|
||||
@@ -4046,7 +4046,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array, mask_lhs: array, mask_rhs: array, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: Optional[array] = None, mask_lhs: Optional[array] = None, mask_rhs: Optional[array] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Matrix multiplication with block masking.
|
||||
|
||||
@@ -4189,7 +4189,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype = Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Return the sum along a specified diagonal in the given array.
|
||||
|
||||
@@ -4218,7 +4218,7 @@ void init_ops(nb::module_& m) {
|
||||
"arys"_a,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"),
|
||||
"def atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
|
||||
R"pbdoc(
|
||||
Convert all arrays to have at least one dimension.
|
||||
|
||||
@@ -4240,7 +4240,7 @@ void init_ops(nb::module_& m) {
|
||||
"arys"_a,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"),
|
||||
"def atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
|
||||
R"pbdoc(
|
||||
Convert all arrays to have at least two dimensions.
|
||||
|
||||
@@ -4262,7 +4262,7 @@ void init_ops(nb::module_& m) {
|
||||
"arys"_a,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]"),
|
||||
"def atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, list[array]]"),
|
||||
R"pbdoc(
|
||||
Convert all arrays to have at least three dimensions.
|
||||
|
||||
@@ -4511,7 +4511,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def hadamard_transform(a: array, Optional[float] scale = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def hadamard_transform(a: array, scale: Optional[float] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Perform the Walsh-Hadamard transform along the final axis.
|
||||
|
||||
@@ -4575,7 +4575,7 @@ void init_ops(nb::module_& m) {
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def einsum(subscripts: str, *operands, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def einsum(subscripts: str, *operands, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
|
||||
Perform the Einstein summation convention on the operands.
|
||||
|
||||
@@ -93,7 +93,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"num"_a = 2,
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array)"),
|
||||
"def split(key: array, num: int = 2, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Split a PRNG key into sub keys.
|
||||
|
||||
@@ -321,7 +321,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def truncated_normal(lower: Union[scalar, array], upper: Union[scalar, array], shape: Optional[Sequence[int]] = None, dtype: Optional[Dtype] = float32, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate values from a truncated normal distribution.
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
|
||||
#include "mlx/stream.h"
|
||||
@@ -56,8 +57,8 @@ void init_stream(nb::module_& m) {
|
||||
os << s;
|
||||
return os.str();
|
||||
})
|
||||
.def("__eq__", [](const Stream& s1, const Stream& s2) {
|
||||
return s1 == s2;
|
||||
.def("__eq__", [](const Stream& s, const nb::object& other) {
|
||||
return nb::isinstance<Stream>(other) && s == nb::cast<Stream>(other);
|
||||
});
|
||||
|
||||
nb::implicitly_convertible<Device::DeviceType, Device>();
|
||||
|
||||
@@ -178,7 +178,7 @@ auto py_value_and_grad(
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but "
|
||||
<< "scalar array (Union[array, tuple[array, Any, ...]]); but "
|
||||
<< type_name_str(py_value_out) << " was returned.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
@@ -197,7 +197,7 @@ auto py_value_and_grad(
|
||||
msg << error_msg_tag << " The return value of the function "
|
||||
<< "whose gradient we want to compute should be either a "
|
||||
<< "scalar array or a tuple with the first value being a "
|
||||
<< "scalar array (Union[array, Tuple[array, Any, ...]]); but it "
|
||||
<< "scalar array (Union[array, tuple[array, Any, ...]]); but it "
|
||||
<< "was a tuple with the first value being of type "
|
||||
<< type_name_str(ret[0]) << " .";
|
||||
throw std::invalid_argument(msg.str());
|
||||
@@ -973,13 +973,13 @@ void init_transforms(nb::module_& m) {
|
||||
.def(
|
||||
nb::init<nb::callable>(),
|
||||
"f"_a,
|
||||
nb::sig("def __init__(self, f: callable)"))
|
||||
nb::sig("def __init__(self, f: Callable)"))
|
||||
.def("__call__", &PyCustomFunction::call_impl)
|
||||
.def(
|
||||
"vjp",
|
||||
&PyCustomFunction::set_vjp,
|
||||
"f"_a,
|
||||
nb::sig("def vjp(self, f_vjp: callable)"),
|
||||
nb::sig("def vjp(self, f: Callable)"),
|
||||
R"pbdoc(
|
||||
Define a custom vjp for the wrapped function.
|
||||
|
||||
@@ -1001,7 +1001,7 @@ void init_transforms(nb::module_& m) {
|
||||
"jvp",
|
||||
&PyCustomFunction::set_jvp,
|
||||
"f"_a,
|
||||
nb::sig("def jvp(self, f_jvp: callable)"),
|
||||
nb::sig("def jvp(self, f: Callable)"),
|
||||
R"pbdoc(
|
||||
Define a custom jvp for the wrapped function.
|
||||
|
||||
@@ -1021,7 +1021,7 @@ void init_transforms(nb::module_& m) {
|
||||
"vmap",
|
||||
&PyCustomFunction::set_vmap,
|
||||
"f"_a,
|
||||
nb::sig("def vmap(self, f_vmap: callable)"),
|
||||
nb::sig("def vmap(self, f: Callable)"),
|
||||
R"pbdoc(
|
||||
Define a custom vectorization transformation for the wrapped function.
|
||||
|
||||
@@ -1116,7 +1116,7 @@ void init_transforms(nb::module_& m) {
|
||||
"primals"_a,
|
||||
"tangents"_a,
|
||||
nb::sig(
|
||||
"def jvp(fun: callable, primals: List[array], tangents: List[array]) -> Tuple[List[array], List[array]]"),
|
||||
"def jvp(fun: Callable, primals: list[array], tangents: list[array]) -> tuple[list[array], list[array]]"),
|
||||
R"pbdoc(
|
||||
Compute the Jacobian-vector product.
|
||||
|
||||
@@ -1124,7 +1124,7 @@ void init_transforms(nb::module_& m) {
|
||||
at ``primals`` with the ``tangents``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of :class:`array`
|
||||
fun (Callable): A function which takes a variable number of :class:`array`
|
||||
and returns a single :class:`array` or list of :class:`array`.
|
||||
primals (list(array)): A list of :class:`array` at which to
|
||||
evaluate the Jacobian.
|
||||
@@ -1155,7 +1155,7 @@ void init_transforms(nb::module_& m) {
|
||||
"primals"_a,
|
||||
"cotangents"_a,
|
||||
nb::sig(
|
||||
"def vjp(fun: callable, primals: List[array], cotangents: List[array]) -> Tuple[List[array], List[array]]"),
|
||||
"def vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]"),
|
||||
R"pbdoc(
|
||||
Compute the vector-Jacobian product.
|
||||
|
||||
@@ -1163,7 +1163,7 @@ void init_transforms(nb::module_& m) {
|
||||
function ``fun`` evaluated at ``primals``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of :class:`array`
|
||||
fun (Callable): A function which takes a variable number of :class:`array`
|
||||
and returns a single :class:`array` or list of :class:`array`.
|
||||
primals (list(array)): A list of :class:`array` at which to
|
||||
evaluate the Jacobian.
|
||||
@@ -1189,7 +1189,7 @@ void init_transforms(nb::module_& m) {
|
||||
"argnums"_a = nb::none(),
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
nb::sig(
|
||||
"def value_and_grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
|
||||
"def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
|
||||
R"pbdoc(
|
||||
Returns a function which computes the value and gradient of ``fun``.
|
||||
|
||||
@@ -1221,7 +1221,7 @@ void init_transforms(nb::module_& m) {
|
||||
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a scalar output :class:`array` or a tuple the first element
|
||||
of which should be a scalar :class:`array`.
|
||||
@@ -1235,7 +1235,7 @@ void init_transforms(nb::module_& m) {
|
||||
no gradients for keyword arguments by default.
|
||||
|
||||
Returns:
|
||||
callable: A function which returns a tuple where the first element
|
||||
Callable: A function which returns a tuple where the first element
|
||||
is the output of `fun` and the second element is the gradients w.r.t.
|
||||
the loss.
|
||||
)pbdoc");
|
||||
@@ -1257,12 +1257,12 @@ void init_transforms(nb::module_& m) {
|
||||
"argnums"_a = nb::none(),
|
||||
"argnames"_a = std::vector<std::string>{},
|
||||
nb::sig(
|
||||
"def grad(fun: callable, argnums: Optional[Union[int, List[int]]] = None, argnames: Union[str, List[str]] = []) -> callable"),
|
||||
"def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
|
||||
R"pbdoc(
|
||||
Returns a function which computes the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a scalar output :class:`array`.
|
||||
argnums (int or list(int), optional): Specify the index (or indices)
|
||||
@@ -1275,7 +1275,7 @@ void init_transforms(nb::module_& m) {
|
||||
no gradients for keyword arguments by default.
|
||||
|
||||
Returns:
|
||||
callable: A function which has the same input arguments as ``fun`` and
|
||||
Callable: A function which has the same input arguments as ``fun`` and
|
||||
returns the gradient(s).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
@@ -1289,12 +1289,12 @@ void init_transforms(nb::module_& m) {
|
||||
"in_axes"_a = 0,
|
||||
"out_axes"_a = 0,
|
||||
nb::sig(
|
||||
"def vmap(fun: callable, in_axes: object = 0, out_axes: object = 0) -> callable"),
|
||||
"def vmap(fun: Callable, in_axes: object = 0, out_axes: object = 0) -> Callable"),
|
||||
R"pbdoc(
|
||||
Returns a vectorized version of ``fun``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or a tree of :class:`array` and returns
|
||||
a variable number of :class:`array` or a tree of :class:`array`.
|
||||
in_axes (int, optional): An integer or a valid prefix tree of the
|
||||
@@ -1307,7 +1307,7 @@ void init_transforms(nb::module_& m) {
|
||||
Defaults to ``0``.
|
||||
|
||||
Returns:
|
||||
callable: The vectorized function.
|
||||
Callable: The vectorized function.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"export_to_dot",
|
||||
@@ -1367,11 +1367,13 @@ void init_transforms(nb::module_& m) {
|
||||
"inputs"_a = nb::none(),
|
||||
"outputs"_a = nb::none(),
|
||||
"shapeless"_a = false,
|
||||
nb::sig(
|
||||
"def compile(fun: Callable, inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable"),
|
||||
R"pbdoc(
|
||||
Returns a compiled function which produces the same output as ``fun``.
|
||||
|
||||
Args:
|
||||
fun (callable): A function which takes a variable number of
|
||||
fun (Callable): A function which takes a variable number of
|
||||
:class:`array` or trees of :class:`array` and returns
|
||||
a variable number of :class:`array` or trees of :class:`array`.
|
||||
inputs (list or dict, optional): These inputs will be captured during
|
||||
@@ -1392,7 +1394,7 @@ void init_transforms(nb::module_& m) {
|
||||
``shapeless`` set to ``True``. Default: ``False``
|
||||
|
||||
Returns:
|
||||
callable: A compiled function which has the same input arguments
|
||||
Callable: A compiled function which has the same input arguments
|
||||
as ``fun`` and returns the the same output(s).
|
||||
)pbdoc");
|
||||
m.def(
|
||||
|
||||
Reference in New Issue
Block a user