mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)
* Add type annotations to mlx methods * Missing list_or_scalar
This commit is contained in:
		| @@ -9,6 +9,7 @@ mlx.core.__prefix__: | ||||
| mlx.core.__suffix__: | ||||
|   from typing import Union | ||||
|   scalar: TypeAlias = Union[int, float, bool] | ||||
|   list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]] | ||||
|   bool_: Dtype = ... | ||||
|  | ||||
| mlx.core.distributed.__prefix__: | ||||
| @@ -29,5 +30,5 @@ mlx.core.metal.__prefix__: | ||||
|   from typing import Sequence, Optional, Union | ||||
|  | ||||
| mlx.core.random.__prefix__: | ||||
|   from mlx.core import array, Dtype, Device, Stream, scalar | ||||
|   from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32 | ||||
|   from typing import Sequence, Optional, Union | ||||
|   | ||||
| @@ -320,6 +320,7 @@ void init_array(nb::module_& m) { | ||||
|       .def_prop_ro( | ||||
|           "shape", | ||||
|           [](const mx::array& a) { return nb::cast(a.shape()); }, | ||||
|           nb::sig("def shape(self) -> tuple[int, ...]"), | ||||
|           R"pbdoc( | ||||
|           The shape of the array as a Python tuple. | ||||
|  | ||||
| @@ -347,6 +348,7 @@ void init_array(nb::module_& m) { | ||||
|       .def( | ||||
|           "item", | ||||
|           &to_scalar, | ||||
|           nb::sig("def item(self) -> scalar"), | ||||
|           R"pbdoc( | ||||
|             Access the value of a scalar array. | ||||
|  | ||||
| @@ -356,6 +358,7 @@ void init_array(nb::module_& m) { | ||||
|       .def( | ||||
|           "tolist", | ||||
|           &tolist, | ||||
|           nb::sig("def tolist(self) -> list_or_scalar"), | ||||
|           R"pbdoc( | ||||
|             Convert the array to a Python :class:`list`. | ||||
|  | ||||
|   | ||||
| @@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) { | ||||
|       "a"_a, | ||||
|       nb::kw_only(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"), | ||||
|       R"pbdoc( | ||||
|         Compute the eigenvalues and eigenvectors of a square matrix. | ||||
|  | ||||
| @@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) { | ||||
|       "UPLO"_a = "L", | ||||
|       nb::kw_only(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"), | ||||
|       R"pbdoc( | ||||
|         Compute the eigenvalues and eigenvectors of a complex Hermitian or | ||||
|         real symmetric matrix. | ||||
|   | ||||
| @@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) { | ||||
|       "key"_a = nb::none(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|           "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|         Generate normally distributed random numbers. | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 XXXXRT666
					XXXXRT666