mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
pin cuda deps
This commit is contained in:
@@ -204,7 +204,7 @@ jobs:
|
|||||||
|
|
||||||
cuda_build_and_test:
|
cuda_build_and_test:
|
||||||
machine:
|
machine:
|
||||||
image: linux-cuda-12:2023.11.1
|
image: linux-cuda-12:2025.05.1
|
||||||
resource_class: gpu.nvidia.small.gen2
|
resource_class: gpu.nvidia.small.gen2
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
|
|||||||
13
setup.py
13
setup.py
@@ -205,9 +205,6 @@ if __name__ == "__main__":
|
|||||||
build_macos = platform.system() == "Darwin"
|
build_macos = platform.system() == "Darwin"
|
||||||
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
|
||||||
|
|
||||||
install_requires = []
|
|
||||||
if build_cuda:
|
|
||||||
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
|
|
||||||
version = get_version()
|
version = get_version()
|
||||||
|
|
||||||
_setup = partial(
|
_setup = partial(
|
||||||
@@ -250,6 +247,7 @@ if __name__ == "__main__":
|
|||||||
"mlx.distributed_config = mlx.distributed_run:distributed_config",
|
"mlx.distributed_config = mlx.distributed_run:distributed_config",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
install_requires = []
|
||||||
|
|
||||||
# Release builds for PyPi are in two stages.
|
# Release builds for PyPi are in two stages.
|
||||||
# Each stage should be run from a clean build:
|
# Each stage should be run from a clean build:
|
||||||
@@ -269,7 +267,9 @@ if __name__ == "__main__":
|
|||||||
# - Package name is back-end specific, e.g mlx-metal
|
# - Package name is back-end specific, e.g mlx-metal
|
||||||
if build_stage != 2:
|
if build_stage != 2:
|
||||||
if build_stage == 1:
|
if build_stage == 1:
|
||||||
install_requires += [f'mlx-metal=={version}; platform_system == "Darwin"']
|
install_requires.append(
|
||||||
|
f'mlx-metal=={version}; platform_system == "Darwin"'
|
||||||
|
)
|
||||||
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
|
||||||
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
|
||||||
|
|
||||||
@@ -286,9 +286,14 @@ if __name__ == "__main__":
|
|||||||
name = "mlx-metal"
|
name = "mlx-metal"
|
||||||
elif build_cuda:
|
elif build_cuda:
|
||||||
name = "mlx-cuda"
|
name = "mlx-cuda"
|
||||||
|
install_requires += [
|
||||||
|
"nvidia-cublas-cu12==12.9.*",
|
||||||
|
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
name = "mlx-cpu"
|
name = "mlx-cpu"
|
||||||
_setup(
|
_setup(
|
||||||
name=name,
|
name=name,
|
||||||
packages=["mlx"],
|
packages=["mlx"],
|
||||||
|
install_requires=install_requries,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user