Fix including stubs in wheel (#2398)

* fix including stubs in wheel

* fix bool_
This commit is contained in:
Awni Hannun 2025-07-22 06:30:17 -07:00 committed by GitHub
parent 56cc858af9
commit 08638223ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,7 +9,7 @@ from functools import partial
from pathlib import Path from pathlib import Path
from subprocess import run from subprocess import run
from setuptools import Command, Extension, setup from setuptools import Command, Extension, find_namespace_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.bdist_wheel import bdist_wheel
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
@ -166,6 +166,10 @@ class GenerateStubs(Command):
# Run again without recursive to specify output file name # Run again without recursive to specify output file name
subprocess.run(["rm", f"{out_path}/mlx.pyi"]) subprocess.run(["rm", f"{out_path}/mlx.pyi"])
subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"]) subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"])
# mx.bool_ gets filtered by nanobind because of the trailing
# underscore, add it manually:
with open(f"{out_path}/__init__.pyi", "a") as fid:
fid.write("\nbool_: Dtype = ...")
class MLXBdistWheel(bdist_wheel): class MLXBdistWheel(bdist_wheel):
@ -184,12 +188,19 @@ with open(Path(__file__).parent / "README.md", encoding="utf-8") as f:
if __name__ == "__main__": if __name__ == "__main__":
package_dir = {"": "python"} package_dir = {"": "python"}
packages = [ packages = find_namespace_packages(
"mlx", where="python",
"mlx.nn", exclude=[
"mlx.nn.layers", "src",
"mlx.optimizers", "tests",
] "scripts",
"mlx.lib",
"mlx.include",
"mlx.share",
"mlx.share.**",
"mlx.include.**",
],
)
build_macos = platform.system() == "Darwin" build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
@ -221,7 +232,7 @@ if __name__ == "__main__":
}, },
) )
package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} package_data = {"mlx.core": ["*.pyi"]}
extras = { extras = {
"dev": [ "dev": [