diff --git a/python/mlx/_stub_patterns.txt b/python/mlx/_stub_patterns.txt new file mode 100644 index 000000000..7cc6826ed --- /dev/null +++ b/python/mlx/_stub_patterns.txt @@ -0,0 +1,20 @@ +mlx.core.distributed.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from mlx.core.distributed import Group + from typing import Sequence, Optional, Union + +mlx.core.fast.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from typing import Sequence, Optional, Union + +mlx.core.linalg.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from typing import Sequence, Optional, Tuple, Union + +mlx.core.metal.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from typing import Sequence, Optional, Union + +mlx.core.random.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from typing import Sequence, Optional, Union diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 65dd8d0e4..13d61e980 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array)"), + "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"), R"pbdoc( The QR factorization of the input matrix. @@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array, array)"), + "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), R"pbdoc( The Singular Value Decomposition (SVD) of the input matrix. diff --git a/setup.py b/setup.py index 85740bf40..d4ac4a954 100644 --- a/setup.py +++ b/setup.py @@ -141,6 +141,8 @@ class GenerateStubs(Command): "nanobind.stubgen", "-m", "mlx.core", + "-p", + "python/mlx/_stub_patterns.txt", ] subprocess.run(stub_cmd + ["-r", "-O", out_path]) # Run again without recursive to specify output file name