Fix including stubs in wheel (#2398)

* fix including stubs in wheel

* fix bool_
This commit is contained in:
Awni Hannun 2025-07-22 06:30:17 -07:00 committed by GitHub
parent 56cc858af9
commit 08638223ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,7 +9,7 @@ from functools import partial
from pathlib import Path
from subprocess import run
from setuptools import Command, Extension, setup
from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext
@ -166,6 +166,10 @@ class GenerateStubs(Command):
# Run again without recursive to specify output file name
subprocess.run(["rm", f"{out_path}/mlx.pyi"])
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
# mx.bool_ gets filtered by nanobind because of the trailing
# underscore, add it manually:
with open(f"{out_path}/__init__.pyi", "a") as fid:
fid.write("\nbool_: Dtype = ...")
class MLXBdistWheel(bdist_wheel):
@ -184,12 +188,19 @@ with open(Path(__file__).parent / "README.md", encoding="utf-8") as f:
if __name__ == "__main__":
package_dir = {"": "python"}
packages = [
"mlx",
"mlx.nn",
"mlx.nn.layers",
"mlx.optimizers",
]
packages = find_namespace_packages(
where="python",
exclude=[
"src",
"tests",
"scripts",
"mlx.lib",
"mlx.include",
"mlx.share",
"mlx.share.**",
"mlx.include.**",
],
)
build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
@ -221,7 +232,7 @@ if __name__ == "__main__":
},
)
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]}
package_data = {"mlx.core": ["*.pyi"]}
extras = {
"dev": [