mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix including stubs in wheel (#2398)
* fix including stubs in wheel * fix bool_
This commit is contained in:
		
							
								
								
									
										27
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								setup.py
									
									
									
									
									
								
							| @@ -9,7 +9,7 @@ from functools import partial | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from subprocess import run | from subprocess import run | ||||||
|  |  | ||||||
| from setuptools import Command, Extension, setup | from setuptools import Command, Extension, find_namespace_packages, setup | ||||||
| from setuptools.command.bdist_wheel import bdist_wheel | from setuptools.command.bdist_wheel import bdist_wheel | ||||||
| from setuptools.command.build_ext import build_ext | from setuptools.command.build_ext import build_ext | ||||||
|  |  | ||||||
| @@ -166,6 +166,10 @@ class GenerateStubs(Command): | |||||||
|         # Run again without recursive to specify output file name |         # Run again without recursive to specify output file name | ||||||
|         subprocess.run(["rm", f"{out_path}/mlx.pyi"]) |         subprocess.run(["rm", f"{out_path}/mlx.pyi"]) | ||||||
|         subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"]) |         subprocess.run(stub_cmd + ["-o", f"{out_path}/__init__.pyi"]) | ||||||
|  |         # mx.bool_ gets filtered by nanobind because of the trailing | ||||||
|  |         # underscore, add it manually: | ||||||
|  |         with open(f"{out_path}/__init__.pyi", "a") as fid: | ||||||
|  |             fid.write("\nbool_: Dtype = ...") | ||||||
|  |  | ||||||
|  |  | ||||||
| class MLXBdistWheel(bdist_wheel): | class MLXBdistWheel(bdist_wheel): | ||||||
| @@ -184,12 +188,19 @@ with open(Path(__file__).parent / "README.md", encoding="utf-8") as f: | |||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     package_dir = {"": "python"} |     package_dir = {"": "python"} | ||||||
|     packages = [ |     packages = find_namespace_packages( | ||||||
|         "mlx", |         where="python", | ||||||
|         "mlx.nn", |         exclude=[ | ||||||
|         "mlx.nn.layers", |             "src", | ||||||
|         "mlx.optimizers", |             "tests", | ||||||
|     ] |             "scripts", | ||||||
|  |             "mlx.lib", | ||||||
|  |             "mlx.include", | ||||||
|  |             "mlx.share", | ||||||
|  |             "mlx.share.**", | ||||||
|  |             "mlx.include.**", | ||||||
|  |         ], | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     build_macos = platform.system() == "Darwin" |     build_macos = platform.system() == "Darwin" | ||||||
|     build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") |     build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") | ||||||
| @@ -221,7 +232,7 @@ if __name__ == "__main__": | |||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     package_data = {"mlx": ["lib/*", "include/*", "share/*"], "mlx.core": ["*.pyi"]} |     package_data = {"mlx.core": ["*.pyi"]} | ||||||
|  |  | ||||||
|     extras = { |     extras = { | ||||||
|         "dev": [ |         "dev": [ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun