From f9f8c167d43085391cf210e2214835c1695f04fa Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 15 Oct 2024 16:23:37 -0700 Subject: [PATCH] fix submodule stubs (#1492) --- python/mlx/_stub_patterns.txt | 20 ++++++++++++++++++++ python/src/linalg.cpp | 4 ++-- setup.py | 2 ++ 3 files changed, 24 insertions(+), 2 deletions(-) create mode 100644 python/mlx/_stub_patterns.txt diff --git a/python/mlx/_stub_patterns.txt b/python/mlx/_stub_patterns.txt new file mode 100644 index 000000000..7cc6826ed --- /dev/null +++ b/python/mlx/_stub_patterns.txt @@ -0,0 +1,20 @@ +mlx.core.distributed.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from mlx.core.distributed import Group + from typing import Sequence, Optional, Union + +mlx.core.fast.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from typing import Sequence, Optional, Union + +mlx.core.linalg.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from typing import Sequence, Optional, Tuple, Union + +mlx.core.metal.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from typing import Sequence, Optional, Union + +mlx.core.random.__prefix__: + from mlx.core import array, Dtype, Device, Stream + from typing import Sequence, Optional, Union diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 65dd8d0e4..13d61e980 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array)"), + "def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"), R"pbdoc( The QR factorization of the input matrix. @@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array, array)"), + "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), R"pbdoc( The Singular Value Decomposition (SVD) of the input matrix. diff --git a/setup.py b/setup.py index 85740bf40..d4ac4a954 100644 --- a/setup.py +++ b/setup.py @@ -141,6 +141,8 @@ class GenerateStubs(Command): "nanobind.stubgen", "-m", "mlx.core", + "-p", + "python/mlx/_stub_patterns.txt", ] subprocess.run(stub_cmd + ["-r", "-O", out_path]) # Run again without recursive to specify output file name