mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-28 00:58:15 +08:00
typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)
* Add type annotations to mlx methods * Missing list_or_scalar
This commit is contained in:
@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
|
||||
.def_prop_ro(
|
||||
"shape",
|
||||
[](const mx::array& a) { return nb::cast(a.shape()); },
|
||||
nb::sig("def shape(self) -> tuple[int, ...]"),
|
||||
R"pbdoc(
|
||||
The shape of the array as a Python tuple.
|
||||
|
||||
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"item",
|
||||
&to_scalar,
|
||||
nb::sig("def item(self) -> scalar"),
|
||||
R"pbdoc(
|
||||
Access the value of a scalar array.
|
||||
|
||||
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"tolist",
|
||||
&tolist,
|
||||
nb::sig("def tolist(self) -> list_or_scalar"),
|
||||
R"pbdoc(
|
||||
Convert the array to a Python :class:`list`.
|
||||
|
||||
|
@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||
R"pbdoc(
|
||||
Compute the eigenvalues and eigenvectors of a square matrix.
|
||||
|
||||
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
|
||||
"UPLO"_a = "L",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||
R"pbdoc(
|
||||
Compute the eigenvalues and eigenvectors of a complex Hermitian or
|
||||
real symmetric matrix.
|
||||
|
@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate normally distributed random numbers.
|
||||
|
||||
|
Reference in New Issue
Block a user