mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-20 03:48:15 +08:00
typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)
* Add type annotations to mlx methods * Missing list_or_scalar
This commit is contained in:
@@ -9,6 +9,7 @@ mlx.core.__prefix__:
|
|||||||
mlx.core.__suffix__:
|
mlx.core.__suffix__:
|
||||||
from typing import Union
|
from typing import Union
|
||||||
scalar: TypeAlias = Union[int, float, bool]
|
scalar: TypeAlias = Union[int, float, bool]
|
||||||
|
list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
|
||||||
bool_: Dtype = ...
|
bool_: Dtype = ...
|
||||||
|
|
||||||
mlx.core.distributed.__prefix__:
|
mlx.core.distributed.__prefix__:
|
||||||
@@ -29,5 +30,5 @@ mlx.core.metal.__prefix__:
|
|||||||
from typing import Sequence, Optional, Union
|
from typing import Sequence, Optional, Union
|
||||||
|
|
||||||
mlx.core.random.__prefix__:
|
mlx.core.random.__prefix__:
|
||||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
|
||||||
from typing import Sequence, Optional, Union
|
from typing import Sequence, Optional, Union
|
||||||
|
@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
|
|||||||
.def_prop_ro(
|
.def_prop_ro(
|
||||||
"shape",
|
"shape",
|
||||||
[](const mx::array& a) { return nb::cast(a.shape()); },
|
[](const mx::array& a) { return nb::cast(a.shape()); },
|
||||||
|
nb::sig("def shape(self) -> tuple[int, ...]"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
The shape of the array as a Python tuple.
|
The shape of the array as a Python tuple.
|
||||||
|
|
||||||
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
|
|||||||
.def(
|
.def(
|
||||||
"item",
|
"item",
|
||||||
&to_scalar,
|
&to_scalar,
|
||||||
|
nb::sig("def item(self) -> scalar"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Access the value of a scalar array.
|
Access the value of a scalar array.
|
||||||
|
|
||||||
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
|
|||||||
.def(
|
.def(
|
||||||
"tolist",
|
"tolist",
|
||||||
&tolist,
|
&tolist,
|
||||||
|
nb::sig("def tolist(self) -> list_or_scalar"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Convert the array to a Python :class:`list`.
|
Convert the array to a Python :class:`list`.
|
||||||
|
|
||||||
|
@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
"a"_a,
|
"a"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Compute the eigenvalues and eigenvectors of a square matrix.
|
Compute the eigenvalues and eigenvectors of a square matrix.
|
||||||
|
|
||||||
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
|
|||||||
"UPLO"_a = "L",
|
"UPLO"_a = "L",
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Compute the eigenvalues and eigenvectors of a complex Hermitian or
|
Compute the eigenvalues and eigenvectors of a complex Hermitian or
|
||||||
real symmetric matrix.
|
real symmetric matrix.
|
||||||
|
@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
|
|||||||
"key"_a = nb::none(),
|
"key"_a = nb::none(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Generate normally distributed random numbers.
|
Generate normally distributed random numbers.
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user