typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)

* Add type annotations to mlx methods

* Missing list_or_scalar
This commit is contained in:
XXXXRT666
2025-09-05 00:08:11 +08:00
committed by GitHub
parent 89a3df9014
commit 8f163a367d
4 changed files with 10 additions and 2 deletions

View File

@@ -9,6 +9,7 @@ mlx.core.__prefix__:
mlx.core.__suffix__: mlx.core.__suffix__:
from typing import Union from typing import Union
scalar: TypeAlias = Union[int, float, bool] scalar: TypeAlias = Union[int, float, bool]
list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
bool_: Dtype = ... bool_: Dtype = ...
mlx.core.distributed.__prefix__: mlx.core.distributed.__prefix__:
@@ -29,5 +30,5 @@ mlx.core.metal.__prefix__:
from typing import Sequence, Optional, Union from typing import Sequence, Optional, Union
mlx.core.random.__prefix__: mlx.core.random.__prefix__:
from mlx.core import array, Dtype, Device, Stream, scalar from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
from typing import Sequence, Optional, Union from typing import Sequence, Optional, Union

View File

@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
.def_prop_ro( .def_prop_ro(
"shape", "shape",
[](const mx::array& a) { return nb::cast(a.shape()); }, [](const mx::array& a) { return nb::cast(a.shape()); },
nb::sig("def shape(self) -> tuple[int, ...]"),
R"pbdoc( R"pbdoc(
The shape of the array as a Python tuple. The shape of the array as a Python tuple.
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
.def( .def(
"item", "item",
&to_scalar, &to_scalar,
nb::sig("def item(self) -> scalar"),
R"pbdoc( R"pbdoc(
Access the value of a scalar array. Access the value of a scalar array.
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
.def( .def(
"tolist", "tolist",
&tolist, &tolist,
nb::sig("def tolist(self) -> list_or_scalar"),
R"pbdoc( R"pbdoc(
Convert the array to a Python :class:`list`. Convert the array to a Python :class:`list`.

View File

@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
"a"_a, "a"_a,
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig(
"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc( R"pbdoc(
Compute the eigenvalues and eigenvectors of a square matrix. Compute the eigenvalues and eigenvectors of a square matrix.
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
"UPLO"_a = "L", "UPLO"_a = "L",
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig(
"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc( R"pbdoc(
Compute the eigenvalues and eigenvectors of a complex Hermitian or Compute the eigenvalues and eigenvectors of a complex Hermitian or
real symmetric matrix. real symmetric matrix.

View File

@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
"key"_a = nb::none(), "key"_a = nb::none(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"), "def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Generate normally distributed random numbers. Generate normally distributed random numbers.