pin cuda deps

This commit is contained in:
Awni Hannun
2025-07-23 10:00:11 -07:00
parent 3b9c665cb8
commit 0d8a7d8248
2 changed files with 10 additions and 5 deletions

View File

@@ -204,7 +204,7 @@ jobs:
cuda_build_and_test: cuda_build_and_test:
machine: machine:
image: linux-cuda-12:2023.11.1 image: linux-cuda-12:2025.05.1
resource_class: gpu.nvidia.small.gen2 resource_class: gpu.nvidia.small.gen2
steps: steps:
- checkout - checkout

View File

@@ -205,9 +205,6 @@ if __name__ == "__main__":
build_macos = platform.system() == "Darwin" build_macos = platform.system() == "Darwin"
build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "") build_cuda = "MLX_BUILD_CUDA=ON" in os.environ.get("CMAKE_ARGS", "")
install_requires = []
if build_cuda:
install_requires = ["nvidia-cublas-cu12", "nvidia-cuda-nvrtc-cu12"]
version = get_version() version = get_version()
_setup = partial( _setup = partial(
@@ -250,6 +247,7 @@ if __name__ == "__main__":
"mlx.distributed_config = mlx.distributed_run:distributed_config", "mlx.distributed_config = mlx.distributed_run:distributed_config",
] ]
} }
install_requires = []
# Release builds for PyPi are in two stages. # Release builds for PyPi are in two stages.
# Each stage should be run from a clean build: # Each stage should be run from a clean build:
@@ -269,7 +267,9 @@ if __name__ == "__main__":
# - Package name is back-end specific, e.g mlx-metal # - Package name is back-end specific, e.g mlx-metal
if build_stage != 2: if build_stage != 2:
if build_stage == 1: if build_stage == 1:
install_requires += [f'mlx-metal=={version}; platform_system == "Darwin"'] install_requires.append(
f'mlx-metal=={version}; platform_system == "Darwin"'
)
extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"'] extras["cuda"] = [f'mlx-cuda=={version}; platform_system == "Linux"']
extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"'] extras["cpu"] = [f'mlx-cpu=={version}; platform_system == "Linux"']
@@ -286,9 +286,14 @@ if __name__ == "__main__":
name = "mlx-metal" name = "mlx-metal"
elif build_cuda: elif build_cuda:
name = "mlx-cuda" name = "mlx-cuda"
install_requires += [
"nvidia-cublas-cu12==12.9.*",
"nvidia-cuda-nvrtc-cu12==12.9.*",
]
else: else:
name = "mlx-cpu" name = "mlx-cpu"
_setup( _setup(
name=name, name=name,
packages=["mlx"], packages=["mlx"],
install_requires=install_requries,
) )