JAX: add v0.4.27, NCCL variant (#44071)
This commit is contained in:
		| @@ -19,11 +19,12 @@ class PyJax(PythonPackage): | ||||
|     arbitrarily to any order.""" | ||||
| 
 | ||||
|     homepage = "https://github.com/google/jax" | ||||
|     pypi = "jax/jax-0.2.25.tar.gz" | ||||
|     pypi = "jax/jax-0.4.27.tar.gz" | ||||
| 
 | ||||
|     license("Apache-2.0") | ||||
|     maintainers("adamjstewart", "jonas-eschle") | ||||
| 
 | ||||
|     version("0.4.27", sha256="f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c") | ||||
|     version("0.4.26", sha256="2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383") | ||||
|     version("0.4.25", sha256="a8ee189c782de2b7b2ffb64a8916da380b882a617e2769aa429b71d79747b982") | ||||
|     version("0.4.24", sha256="4a6b6fd026ddd22653c7fa2fac1904c3de2dbe845b61ede08af9a5cc709662ae") | ||||
| @@ -59,26 +60,30 @@ class PyJax(PythonPackage): | ||||
|         deprecated=True, | ||||
|     ) | ||||
| 
 | ||||
|     depends_on("python@3.9:", when="@0.4.14:", type=("build", "run")) | ||||
|     depends_on("python@3.8:", when="@0.4:", type=("build", "run")) | ||||
|     depends_on("py-setuptools", type="build") | ||||
|     depends_on("py-ml-dtypes@0.2:", when="@0.4.14:", type=("build", "run")) | ||||
|     depends_on("py-ml-dtypes@0.1:", when="@0.4.9:", type=("build", "run")) | ||||
|     depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run")) | ||||
|     depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run")) | ||||
|     depends_on("py-numpy@1.21:", when="@0.4.7:", type=("build", "run")) | ||||
|     depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run")) | ||||
|     depends_on("py-numpy@1.18:", type=("build", "run")) | ||||
|     depends_on("py-opt-einsum", type=("build", "run")) | ||||
|     depends_on("py-scipy@1.9:", when="@0.4.19:", type=("build", "run")) | ||||
|     depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run")) | ||||
|     depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run")) | ||||
|     depends_on("py-scipy@1.2.1:", type=("build", "run")) | ||||
|     depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9", type=("build", "run")) | ||||
| 
 | ||||
|     # See jax/_src/lib/__init__.py | ||||
|     with default_args(type=("build", "run")): | ||||
|         # setup.py | ||||
|         depends_on("python@3.9:", when="@0.4.14:") | ||||
|         depends_on("python@3.8:", when="@0.4:") | ||||
|         depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") | ||||
|         depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") | ||||
|         depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") | ||||
|         depends_on("py-numpy@1.22:", when="@0.4.14:") | ||||
|         depends_on("py-numpy@1.21:", when="@0.4.7:") | ||||
|         depends_on("py-numpy@1.20:", when="@0.3:") | ||||
|         depends_on("py-numpy@1.18:") | ||||
|         depends_on("py-opt-einsum") | ||||
|         depends_on("py-scipy@1.9:", when="@0.4.19:") | ||||
|         depends_on("py-scipy@1.7:", when="@0.4.7:") | ||||
|         depends_on("py-scipy@1.5:", when="@0.3:") | ||||
|         depends_on("py-scipy@1.2.1:") | ||||
|         depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9") | ||||
| 
 | ||||
|         # jax/_src/lib/__init__.py | ||||
|         # https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4 | ||||
|         for v in [ | ||||
|             "0.4.27", | ||||
|             "0.4.26", | ||||
|             "0.4.25", | ||||
|             "0.4.24", | ||||
| @@ -105,28 +110,29 @@ class PyJax(PythonPackage): | ||||
|             "0.4.3", | ||||
|             "0.3.23", | ||||
|         ]: | ||||
|         depends_on(f"py-jaxlib@:{v}", when=f"@{v}", type=("build", "run")) | ||||
|             depends_on(f"py-jaxlib@:{v}", when=f"@{v}") | ||||
| 
 | ||||
|         # See _minimum_jaxlib_version in jax/version.py | ||||
|     depends_on("py-jaxlib@0.4.20:", when="@0.4.25:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.4.19:", when="@0.4.21:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.4.14:", when="@0.4.15:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.4.11:", when="@0.4.12:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.4.7:", when="@0.4.8:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.4.6:", when="@0.4.7:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.4.4:", when="@0.4.5:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.4.2:", when="@0.4.3:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.4.1:", when="@0.4.2:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.3.22:", when="@0.3.24:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.3.15:", when="@0.3.18:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.3.14:", when="@0.3.15:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.3.7:", when="@0.3.8:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.3.2:", when="@0.3.7:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.3.0:", when="@0.3.2:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.1.74:", when="@0.2.26:", type=("build", "run")) | ||||
|     depends_on("py-jaxlib@0.1.69:", when="@0.2.18:", type=("build", "run")) | ||||
|         depends_on("py-jaxlib@0.4.23:", when="@0.4.27:") | ||||
|         depends_on("py-jaxlib@0.4.20:", when="@0.4.25:") | ||||
|         depends_on("py-jaxlib@0.4.19:", when="@0.4.21:") | ||||
|         depends_on("py-jaxlib@0.4.14:", when="@0.4.15:") | ||||
|         depends_on("py-jaxlib@0.4.11:", when="@0.4.12:") | ||||
|         depends_on("py-jaxlib@0.4.7:", when="@0.4.8:") | ||||
|         depends_on("py-jaxlib@0.4.6:", when="@0.4.7:") | ||||
|         depends_on("py-jaxlib@0.4.4:", when="@0.4.5:") | ||||
|         depends_on("py-jaxlib@0.4.2:", when="@0.4.3:") | ||||
|         depends_on("py-jaxlib@0.4.1:", when="@0.4.2:") | ||||
|         depends_on("py-jaxlib@0.3.22:", when="@0.3.24:") | ||||
|         depends_on("py-jaxlib@0.3.15:", when="@0.3.18:") | ||||
|         depends_on("py-jaxlib@0.3.14:", when="@0.3.15:") | ||||
|         depends_on("py-jaxlib@0.3.7:", when="@0.3.8:") | ||||
|         depends_on("py-jaxlib@0.3.2:", when="@0.3.7:") | ||||
|         depends_on("py-jaxlib@0.3.0:", when="@0.3.2:") | ||||
|         depends_on("py-jaxlib@0.1.74:", when="@0.2.26:") | ||||
|         depends_on("py-jaxlib@0.1.69:", when="@0.2.18:") | ||||
| 
 | ||||
|         # Historical dependencies | ||||
|     depends_on("py-absl-py", when="@:0.3", type=("build", "run")) | ||||
|     depends_on("py-typing-extensions", when="@:0.3", type=("build", "run")) | ||||
|     depends_on("py-etils+epath", when="@0.3", type=("build", "run")) | ||||
|         depends_on("py-absl-py", when="@:0.3") | ||||
|         depends_on("py-typing-extensions", when="@:0.3") | ||||
|         depends_on("py-etils+epath", when="@0.3") | ||||
|   | ||||
| @@ -12,7 +12,7 @@ class PyJaxlib(PythonPackage, CudaPackage): | ||||
|     """XLA library for Jax""" | ||||
| 
 | ||||
|     homepage = "https://github.com/google/jax" | ||||
|     url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz" | ||||
|     url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.27.tar.gz" | ||||
| 
 | ||||
|     tmp_path = "" | ||||
|     buildtmp = "" | ||||
| @@ -20,6 +20,7 @@ class PyJaxlib(PythonPackage, CudaPackage): | ||||
|     license("Apache-2.0") | ||||
|     maintainers("adamjstewart") | ||||
| 
 | ||||
|     version("0.4.27", sha256="c2c82cd9ad3b395d5cbc0affa26a2938e52677a69ca8f0b9ef9922a52cac4f0c") | ||||
|     version("0.4.26", sha256="ddc14da1eaa34f23430d40ad9b9585088575cac439a2fa1c6833a247e1b221fd") | ||||
|     version("0.4.25", sha256="fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8") | ||||
|     version("0.4.24", sha256="c4e6963c2c36f634a9a1765e476a1ed4e6c4a7954465ebf72e29f344c28ddc28") | ||||
| @@ -45,52 +46,63 @@ class PyJaxlib(PythonPackage, CudaPackage): | ||||
|         deprecated=True, | ||||
|     ) | ||||
| 
 | ||||
|     variant("cuda", default=True, description="Build with CUDA") | ||||
|     variant("cuda", default=True, description="Build with CUDA enabled") | ||||
|     variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda") | ||||
| 
 | ||||
|     # docs/installation.md | ||||
|     # jaxlib/setup.py | ||||
|     with when("+cuda"): | ||||
|         depends_on("cuda@12.1:", when="@0.4.26:") | ||||
|         depends_on("cuda@11.8:", when="@0.4.11:") | ||||
|         depends_on("cuda@11.4:", when="@0.4.0:0.4.7") | ||||
|         depends_on("cuda@11.1:", when="@0.3") | ||||
|         depends_on("cuda@11.1:11.7.0", when="@0.1") | ||||
|         depends_on("cudnn@8.9:8", when="@0.4.26:") | ||||
|         depends_on("cudnn@8.8:", when="@0.4.11:") | ||||
|         depends_on("cudnn@8.2:", when="@0.4:0.4.7") | ||||
|         depends_on("cudnn@8.0.5:") | ||||
| 
 | ||||
|     with when("+nccl"): | ||||
|         depends_on("nccl@2.18:", when="@0.4.26:") | ||||
|         depends_on("nccl@2.16:", when="@0.4.18:") | ||||
|         depends_on("nccl") | ||||
| 
 | ||||
|     with default_args(type="build"): | ||||
|         # .bazelversion | ||||
|         depends_on("bazel@6.1.2", when="@0.4.11:") | ||||
|         depends_on("bazel@5.1.1", when="@0.3.7:0.4.10") | ||||
|         depends_on("bazel@5.1.0", when="@0.3.5") | ||||
|         depends_on("bazel@5.0.0", when="@0.3.0:0.3.2") | ||||
|         depends_on("bazel@4.2.1", when="@0.1.75:0.1.76") | ||||
|         depends_on("bazel@4.1.0", when="@0.1.70:0.1.74") | ||||
| 
 | ||||
|         # jaxlib/setup.py | ||||
|         depends_on("py-setuptools") | ||||
| 
 | ||||
|         # build/build.py | ||||
|     depends_on("py-build", when="@0.4.14:", type="build") | ||||
|         depends_on("py-build", when="@0.4.14:") | ||||
| 
 | ||||
|     with default_args(type=("build", "run")): | ||||
|         # Based on PyPI wheels | ||||
|     depends_on("python@3.9:3.12", when="@0.4.17:", type=("build", "run")) | ||||
|     depends_on("python@3.9:3.11", when="@0.4.14:0.4.16", type=("build", "run")) | ||||
|     depends_on("python@3.8:3.11", when="@0.4.6:0.4.13", type=("build", "run")) | ||||
|         depends_on("python@3.9:3.12", when="@0.4.17:") | ||||
|         depends_on("python@3.9:3.11", when="@0.4.14:0.4.16") | ||||
|         depends_on("python@3.8:3.11", when="@0.4.6:0.4.13") | ||||
| 
 | ||||
|         # jaxlib/setup.py | ||||
|     depends_on("py-setuptools", type="build") | ||||
|     depends_on("py-scipy@1.9:", when="@0.4.19:", type=("build", "run")) | ||||
|     depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run")) | ||||
|     depends_on("py-scipy@1.5:", type=("build", "run")) | ||||
|     depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run")) | ||||
|     depends_on("py-numpy@1.21:", when="@0.4.7:", type=("build", "run")) | ||||
|     depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run")) | ||||
|     depends_on("py-numpy@1.18:", type=("build", "run")) | ||||
|     depends_on("py-ml-dtypes@0.2:", when="@0.4.14:", type=("build", "run")) | ||||
|     depends_on("py-ml-dtypes@0.1:", when="@0.4.9:", type=("build", "run")) | ||||
|     depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run")) | ||||
| 
 | ||||
|     # .bazelversion | ||||
|     depends_on("bazel@6.1.2", when="@0.4.11:", type="build") | ||||
|     depends_on("bazel@5.1.1", when="@0.3.7:0.4.10", type="build") | ||||
|     depends_on("bazel@5.1.0", when="@0.3.5", type="build") | ||||
|     depends_on("bazel@5.0.0", when="@0.3.0:0.3.2", type="build") | ||||
|     depends_on("bazel@4.2.1", when="@0.1.75:0.1.76", type="build") | ||||
|     depends_on("bazel@4.1.0", when="@0.1.70:0.1.74", type="build") | ||||
| 
 | ||||
|     # jaxlib/setup.py | ||||
|     depends_on("cuda@12.1.105:", when="@0.4.26:+cuda") | ||||
|     depends_on("cuda@11.8:", when="@0.4.11:+cuda") | ||||
|     depends_on("cuda@11.4:", when="@0.4.0:0.4.7+cuda") | ||||
|     depends_on("cuda@11.1:", when="@0.3+cuda") | ||||
|     # https://github.com/google/jax/issues/12614 | ||||
|     depends_on("cuda@11.1:11.7.0", when="@0.1+cuda") | ||||
| 
 | ||||
|     depends_on("cudnn@8.8:", when="@0.4.11:+cuda") | ||||
|     depends_on("cudnn@8.2:", when="@0.4:0.4.7+cuda") | ||||
|     depends_on("cudnn@8.0.5:", when="+cuda") | ||||
|         depends_on("py-scipy@1.9:", when="@0.4.19:") | ||||
|         depends_on("py-scipy@1.7:", when="@0.4.7:") | ||||
|         depends_on("py-scipy@1.5:") | ||||
|         depends_on("py-numpy@1.22:", when="@0.4.14:") | ||||
|         depends_on("py-numpy@1.21:", when="@0.4.7:") | ||||
|         depends_on("py-numpy@1.20:", when="@0.3:") | ||||
|         depends_on("py-numpy@1.18:") | ||||
|         depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") | ||||
|         depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") | ||||
|         depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") | ||||
| 
 | ||||
|         # Historical dependencies | ||||
|     depends_on("py-absl-py", when="@:0.3", type=("build", "run")) | ||||
|     depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run")) | ||||
|         depends_on("py-absl-py", when="@:0.3") | ||||
|         depends_on("py-flatbuffers@1.12:2", when="@0.1") | ||||
| 
 | ||||
|     conflicts( | ||||
|         "cuda_arch=none", | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Adam J. Stewart
					Adam J. Stewart