mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
typing: add type hints to mlx.core.array, linalg, distributed, and random (#2565)
* Add type annotations to mlx methods * Missing list_or_scalar
This commit is contained in:
@@ -9,6 +9,7 @@ mlx.core.__prefix__:
|
||||
mlx.core.__suffix__:
|
||||
from typing import Union
|
||||
scalar: TypeAlias = Union[int, float, bool]
|
||||
list_or_scalar: TypeAlias = Union[scalar, list["list_or_scalar"]]
|
||||
bool_: Dtype = ...
|
||||
|
||||
mlx.core.distributed.__prefix__:
|
||||
@@ -29,5 +30,5 @@ mlx.core.metal.__prefix__:
|
||||
from typing import Sequence, Optional, Union
|
||||
|
||||
mlx.core.random.__prefix__:
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar
|
||||
from mlx.core import array, Dtype, Device, Stream, scalar, float32, int32
|
||||
from typing import Sequence, Optional, Union
|
||||
|
@@ -320,6 +320,7 @@ void init_array(nb::module_& m) {
|
||||
.def_prop_ro(
|
||||
"shape",
|
||||
[](const mx::array& a) { return nb::cast(a.shape()); },
|
||||
nb::sig("def shape(self) -> tuple[int, ...]"),
|
||||
R"pbdoc(
|
||||
The shape of the array as a Python tuple.
|
||||
|
||||
@@ -347,6 +348,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"item",
|
||||
&to_scalar,
|
||||
nb::sig("def item(self) -> scalar"),
|
||||
R"pbdoc(
|
||||
Access the value of a scalar array.
|
||||
|
||||
@@ -356,6 +358,7 @@ void init_array(nb::module_& m) {
|
||||
.def(
|
||||
"tolist",
|
||||
&tolist,
|
||||
nb::sig("def tolist(self) -> list_or_scalar"),
|
||||
R"pbdoc(
|
||||
Convert the array to a Python :class:`list`.
|
||||
|
||||
|
@@ -447,6 +447,8 @@ void init_linalg(nb::module_& parent_module) {
|
||||
"a"_a,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def eig(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||
R"pbdoc(
|
||||
Compute the eigenvalues and eigenvectors of a square matrix.
|
||||
|
||||
@@ -523,6 +525,8 @@ void init_linalg(nb::module_& parent_module) {
|
||||
"UPLO"_a = "L",
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def eigh(a: array, UPLO: str = 'L', *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
|
||||
R"pbdoc(
|
||||
Compute the eigenvalues and eigenvectors of a complex Hermitian or
|
||||
real symmetric matrix.
|
||||
|
@@ -171,7 +171,7 @@ void init_random(nb::module_& parent_module) {
|
||||
"key"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Optional[scalar, array] = None, scale: Optional[scalar, array] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
"def normal(shape: Sequence[int] = [], dtype: Optional[Dtype] = float32, loc: Union[scalar, array, None] = None, scale: Union[scalar, array, None] = None, key: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Generate normally distributed random numbers.
|
||||
|
||||
|
Reference in New Issue
Block a user