diff --git a/python/mlx/_stub_patterns.txt b/python/mlx/_stub_patterns.txt index 7cc6826ed..2c10dafd1 100644 --- a/python/mlx/_stub_patterns.txt +++ b/python/mlx/_stub_patterns.txt @@ -1,20 +1,33 @@ +mlx.core.__prefix__: + from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + import sys + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + +mlx.core.__suffix__: + from typing import Union + scalar: TypeAlias = Union[int, float, bool] + bool_: Dtype = ... + mlx.core.distributed.__prefix__: - from mlx.core import array, Dtype, Device, Stream + from mlx.core import array, Dtype, Device, Stream, scalar from mlx.core.distributed import Group from typing import Sequence, Optional, Union mlx.core.fast.__prefix__: - from mlx.core import array, Dtype, Device, Stream + from mlx.core import array, Dtype, Device, Stream, scalar from typing import Sequence, Optional, Union mlx.core.linalg.__prefix__: - from mlx.core import array, Dtype, Device, Stream + from mlx.core import array, Dtype, Device, Stream, scalar from typing import Sequence, Optional, Tuple, Union mlx.core.metal.__prefix__: - from mlx.core import array, Dtype, Device, Stream + from mlx.core import array, Dtype, Device, Stream, scalar from typing import Sequence, Optional, Union mlx.core.random.__prefix__: - from mlx.core import array, Dtype, Device, Stream + from mlx.core import array, Dtype, Device, Stream, scalar from typing import Sequence, Optional, Union diff --git a/setup.py b/setup.py index 72646e9ad..b0636a70f 100644 --- a/setup.py +++ b/setup.py @@ -176,10 +176,6 @@ class GenerateStubs(Command): # Run again without recursive to specify output file name subprocess.run(["rm", f"{out_path}/mlx.pyi"]) subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"]) - # mx.bool_ gets filtered by nanobind because of the trailing - # underscore, add it manually: - with open(f"{out_path}/__init__.pyi", "a") as fid: - fid.write("\nbool_: Dtype = ...") class MLXBdistWheel(bdist_wheel):