fix submodule stubs (#1492)

This commit is contained in:
Awni Hannun 2024-10-15 16:23:37 -07:00 committed by GitHub
parent 3f86399922
commit f9f8c167d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 2 deletions

View File

@ -0,0 +1,20 @@
mlx.core.distributed.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from mlx.core.distributed import Group
from typing import Sequence, Optional, Union
mlx.core.fast.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from typing import Sequence, Optional, Union
mlx.core.linalg.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from typing import Sequence, Optional, Tuple, Union
mlx.core.metal.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from typing import Sequence, Optional, Union
mlx.core.random.__prefix__:
from mlx.core import array, Dtype, Device, Stream
from typing import Sequence, Optional, Union

View File

@ -187,7 +187,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array)"),
"def qr(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
R"pbdoc(
The QR factorization of the input matrix.
@ -220,7 +220,7 @@ void init_linalg(nb::module_& parent_module) {
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> tuple(array, array, array)"),
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
R"pbdoc(
The Singular Value Decomposition (SVD) of the input matrix.

View File

@ -141,6 +141,8 @@ class GenerateStubs(Command):
"nanobind.stubgen",
"-m",
"mlx.core",
"-p",
"python/mlx/_stub_patterns.txt",
]
subprocess.run(stub_cmd + ["-r", "-O", out_path])
# Run again without recursive to specify output file name