typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)

* Add type annotations to mlx methods

* Missing list_or_scalar
This commit is contained in:
XXXXRT666
2025-09-05 00:08:11 +08:00
committed by GitHub
parent 89a3df9014
commit 8f163a367d
4 changed files with 10 additions and 2 deletions

View File

@@ -9,6 +9,7 @@ mlx.core.__prefix__:
mlx.core.__suffix__:
from typing import Union
scalar: TypeAlias = Union[int, float, bool]
list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
bool_: Dtype = ...
mlx.core.distributed.__prefix__:
@@ -29,5 +30,5 @@ mlx.core.metal.__prefix__:
from typing import Sequence, Optional, Union
mlx.core.random.__prefix__:
from mlx.core import array, Dtype, Device, Stream, scalar
from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
from typing import Sequence, Optional, Union

View File

@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
.def_prop_ro(
"shape",
[](const mx::array& a) { return nb::cast(a.shape()); },
nb::sig("def shape(self) -> tuple[int, ...]"),
R"pbdoc(
The shape of the array as a Python tuple.
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
.def(
"item",
&to_scalar,
nb::sig("def item(self) -> scalar"),
R"pbdoc(
Access the value of a scalar array.
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
.def(
"tolist",
&tolist,
nb::sig("def tolist(self) -> list_or_scalar"),
R"pbdoc(
Convert the array to a Python :class:`list`.

View File

@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
"a"_a,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a square matrix.
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
"UPLO"_a = "L",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc(
Compute the eigenvalues and eigenvectors of a complex Hermitian or
real symmetric matrix.

View File

@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
"key"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Generate normally distributed random numbers.