JAX: add v0.4.27, NCCL variant (#44071)
This commit is contained in:
parent
9ab6c30a3d
commit
314893982e
@ -19,11 +19,12 @@ class PyJax(PythonPackage):
|
||||
arbitrarily to any order."""
|
||||
|
||||
homepage = "https://github.com/google/jax"
|
||||
pypi = "jax/jax-0.2.25.tar.gz"
|
||||
pypi = "jax/jax-0.4.27.tar.gz"
|
||||
|
||||
license("Apache-2.0")
|
||||
maintainers("adamjstewart", "jonas-eschle")
|
||||
|
||||
version("0.4.27", sha256="f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c")
|
||||
version("0.4.26", sha256="2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383")
|
||||
version("0.4.25", sha256="a8ee189c782de2b7b2ffb64a8916da380b882a617e2769aa429b71d79747b982")
|
||||
version("0.4.24", sha256="4a6b6fd026ddd22653c7fa2fac1904c3de2dbe845b61ede08af9a5cc709662ae")
|
||||
@ -59,74 +60,79 @@ class PyJax(PythonPackage):
|
||||
deprecated=True,
|
||||
)
|
||||
|
||||
depends_on("python@3.9:", when="@0.4.14:", type=("build", "run"))
|
||||
depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
|
||||
depends_on("py-setuptools", type="build")
|
||||
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:", type=("build", "run"))
|
||||
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:", type=("build", "run"))
|
||||
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run"))
|
||||
depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run"))
|
||||
depends_on("py-numpy@1.21:", when="@0.4.7:", type=("build", "run"))
|
||||
depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
|
||||
depends_on("py-numpy@1.18:", type=("build", "run"))
|
||||
depends_on("py-opt-einsum", type=("build", "run"))
|
||||
depends_on("py-scipy@1.9:", when="@0.4.19:", type=("build", "run"))
|
||||
depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run"))
|
||||
depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run"))
|
||||
depends_on("py-scipy@1.2.1:", type=("build", "run"))
|
||||
depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9", type=("build", "run"))
|
||||
|
||||
# See jax/_src/lib/__init__.py
|
||||
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
|
||||
for v in [
|
||||
"0.4.26",
|
||||
"0.4.25",
|
||||
"0.4.24",
|
||||
"0.4.23",
|
||||
"0.4.22",
|
||||
"0.4.21",
|
||||
"0.4.20",
|
||||
"0.4.19",
|
||||
"0.4.18",
|
||||
"0.4.17",
|
||||
"0.4.16",
|
||||
"0.4.15",
|
||||
"0.4.14",
|
||||
"0.4.13",
|
||||
"0.4.12",
|
||||
"0.4.11",
|
||||
"0.4.10",
|
||||
"0.4.9",
|
||||
"0.4.8",
|
||||
"0.4.7",
|
||||
"0.4.6",
|
||||
"0.4.5",
|
||||
"0.4.4",
|
||||
"0.4.3",
|
||||
"0.3.23",
|
||||
]:
|
||||
depends_on(f"py-jaxlib@:{v}", when=f"@{v}", type=("build", "run"))
|
||||
with default_args(type=("build", "run")):
|
||||
# setup.py
|
||||
depends_on("python@3.9:", when="@0.4.14:")
|
||||
depends_on("python@3.8:", when="@0.4:")
|
||||
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
|
||||
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
|
||||
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
|
||||
depends_on("py-numpy@1.22:", when="@0.4.14:")
|
||||
depends_on("py-numpy@1.21:", when="@0.4.7:")
|
||||
depends_on("py-numpy@1.20:", when="@0.3:")
|
||||
depends_on("py-numpy@1.18:")
|
||||
depends_on("py-opt-einsum")
|
||||
depends_on("py-scipy@1.9:", when="@0.4.19:")
|
||||
depends_on("py-scipy@1.7:", when="@0.4.7:")
|
||||
depends_on("py-scipy@1.5:", when="@0.3:")
|
||||
depends_on("py-scipy@1.2.1:")
|
||||
depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9")
|
||||
|
||||
# See _minimum_jaxlib_version in jax/version.py
|
||||
depends_on("py-jaxlib@0.4.20:", when="@0.4.25:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.4.19:", when="@0.4.21:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.4.14:", when="@0.4.15:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.4.11:", when="@0.4.12:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.4.7:", when="@0.4.8:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.4.6:", when="@0.4.7:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.4.4:", when="@0.4.5:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.4.2:", when="@0.4.3:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.4.1:", when="@0.4.2:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.3.22:", when="@0.3.24:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.3.15:", when="@0.3.18:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.3.14:", when="@0.3.15:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.3.7:", when="@0.3.8:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.3.2:", when="@0.3.7:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.3.0:", when="@0.3.2:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.1.74:", when="@0.2.26:", type=("build", "run"))
|
||||
depends_on("py-jaxlib@0.1.69:", when="@0.2.18:", type=("build", "run"))
|
||||
# jax/_src/lib/__init__.py
|
||||
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
|
||||
for v in [
|
||||
"0.4.27",
|
||||
"0.4.26",
|
||||
"0.4.25",
|
||||
"0.4.24",
|
||||
"0.4.23",
|
||||
"0.4.22",
|
||||
"0.4.21",
|
||||
"0.4.20",
|
||||
"0.4.19",
|
||||
"0.4.18",
|
||||
"0.4.17",
|
||||
"0.4.16",
|
||||
"0.4.15",
|
||||
"0.4.14",
|
||||
"0.4.13",
|
||||
"0.4.12",
|
||||
"0.4.11",
|
||||
"0.4.10",
|
||||
"0.4.9",
|
||||
"0.4.8",
|
||||
"0.4.7",
|
||||
"0.4.6",
|
||||
"0.4.5",
|
||||
"0.4.4",
|
||||
"0.4.3",
|
||||
"0.3.23",
|
||||
]:
|
||||
depends_on(f"py-jaxlib@:{v}", when=f"@{v}")
|
||||
|
||||
# Historical dependencies
|
||||
depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
|
||||
depends_on("py-typing-extensions", when="@:0.3", type=("build", "run"))
|
||||
depends_on("py-etils+epath", when="@0.3", type=("build", "run"))
|
||||
# See _minimum_jaxlib_version in jax/version.py
|
||||
depends_on("py-jaxlib@0.4.23:", when="@0.4.27:")
|
||||
depends_on("py-jaxlib@0.4.20:", when="@0.4.25:")
|
||||
depends_on("py-jaxlib@0.4.19:", when="@0.4.21:")
|
||||
depends_on("py-jaxlib@0.4.14:", when="@0.4.15:")
|
||||
depends_on("py-jaxlib@0.4.11:", when="@0.4.12:")
|
||||
depends_on("py-jaxlib@0.4.7:", when="@0.4.8:")
|
||||
depends_on("py-jaxlib@0.4.6:", when="@0.4.7:")
|
||||
depends_on("py-jaxlib@0.4.4:", when="@0.4.5:")
|
||||
depends_on("py-jaxlib@0.4.2:", when="@0.4.3:")
|
||||
depends_on("py-jaxlib@0.4.1:", when="@0.4.2:")
|
||||
depends_on("py-jaxlib@0.3.22:", when="@0.3.24:")
|
||||
depends_on("py-jaxlib@0.3.15:", when="@0.3.18:")
|
||||
depends_on("py-jaxlib@0.3.14:", when="@0.3.15:")
|
||||
depends_on("py-jaxlib@0.3.7:", when="@0.3.8:")
|
||||
depends_on("py-jaxlib@0.3.2:", when="@0.3.7:")
|
||||
depends_on("py-jaxlib@0.3.0:", when="@0.3.2:")
|
||||
depends_on("py-jaxlib@0.1.74:", when="@0.2.26:")
|
||||
depends_on("py-jaxlib@0.1.69:", when="@0.2.18:")
|
||||
|
||||
# Historical dependencies
|
||||
depends_on("py-absl-py", when="@:0.3")
|
||||
depends_on("py-typing-extensions", when="@:0.3")
|
||||
depends_on("py-etils+epath", when="@0.3")
|
||||
|
@ -12,7 +12,7 @@ class PyJaxlib(PythonPackage, CudaPackage):
|
||||
"""XLA library for Jax"""
|
||||
|
||||
homepage = "https://github.com/google/jax"
|
||||
url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz"
|
||||
url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.27.tar.gz"
|
||||
|
||||
tmp_path = ""
|
||||
buildtmp = ""
|
||||
@ -20,6 +20,7 @@ class PyJaxlib(PythonPackage, CudaPackage):
|
||||
license("Apache-2.0")
|
||||
maintainers("adamjstewart")
|
||||
|
||||
version("0.4.27", sha256="c2c82cd9ad3b395d5cbc0affa26a2938e52677a69ca8f0b9ef9922a52cac4f0c")
|
||||
version("0.4.26", sha256="ddc14da1eaa34f23430d40ad9b9585088575cac439a2fa1c6833a247e1b221fd")
|
||||
version("0.4.25", sha256="fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8")
|
||||
version("0.4.24", sha256="c4e6963c2c36f634a9a1765e476a1ed4e6c4a7954465ebf72e29f344c28ddc28")
|
||||
@ -45,52 +46,63 @@ class PyJaxlib(PythonPackage, CudaPackage):
|
||||
deprecated=True,
|
||||
)
|
||||
|
||||
variant("cuda", default=True, description="Build with CUDA")
|
||||
|
||||
# build/build.py
|
||||
depends_on("py-build", when="@0.4.14:", type="build")
|
||||
|
||||
# Based on PyPI wheels
|
||||
depends_on("python@3.9:3.12", when="@0.4.17:", type=("build", "run"))
|
||||
depends_on("python@3.9:3.11", when="@0.4.14:0.4.16", type=("build", "run"))
|
||||
depends_on("python@3.8:3.11", when="@0.4.6:0.4.13", type=("build", "run"))
|
||||
variant("cuda", default=True, description="Build with CUDA enabled")
|
||||
variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda")
|
||||
|
||||
# docs/installation.md
|
||||
# jaxlib/setup.py
|
||||
depends_on("py-setuptools", type="build")
|
||||
depends_on("py-scipy@1.9:", when="@0.4.19:", type=("build", "run"))
|
||||
depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run"))
|
||||
depends_on("py-scipy@1.5:", type=("build", "run"))
|
||||
depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run"))
|
||||
depends_on("py-numpy@1.21:", when="@0.4.7:", type=("build", "run"))
|
||||
depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
|
||||
depends_on("py-numpy@1.18:", type=("build", "run"))
|
||||
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:", type=("build", "run"))
|
||||
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:", type=("build", "run"))
|
||||
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run"))
|
||||
with when("+cuda"):
|
||||
depends_on("cuda@12.1:", when="@0.4.26:")
|
||||
depends_on("cuda@11.8:", when="@0.4.11:")
|
||||
depends_on("cuda@11.4:", when="@0.4.0:0.4.7")
|
||||
depends_on("cuda@11.1:", when="@0.3")
|
||||
depends_on("cuda@11.1:11.7.0", when="@0.1")
|
||||
depends_on("cudnn@8.9:8", when="@0.4.26:")
|
||||
depends_on("cudnn@8.8:", when="@0.4.11:")
|
||||
depends_on("cudnn@8.2:", when="@0.4:0.4.7")
|
||||
depends_on("cudnn@8.0.5:")
|
||||
|
||||
# .bazelversion
|
||||
depends_on("bazel@6.1.2", when="@0.4.11:", type="build")
|
||||
depends_on("bazel@5.1.1", when="@0.3.7:0.4.10", type="build")
|
||||
depends_on("bazel@5.1.0", when="@0.3.5", type="build")
|
||||
depends_on("bazel@5.0.0", when="@0.3.0:0.3.2", type="build")
|
||||
depends_on("bazel@4.2.1", when="@0.1.75:0.1.76", type="build")
|
||||
depends_on("bazel@4.1.0", when="@0.1.70:0.1.74", type="build")
|
||||
with when("+nccl"):
|
||||
depends_on("nccl@2.18:", when="@0.4.26:")
|
||||
depends_on("nccl@2.16:", when="@0.4.18:")
|
||||
depends_on("nccl")
|
||||
|
||||
# jaxlib/setup.py
|
||||
depends_on("cuda@12.1.105:", when="@0.4.26:+cuda")
|
||||
depends_on("cuda@11.8:", when="@0.4.11:+cuda")
|
||||
depends_on("cuda@11.4:", when="@0.4.0:0.4.7+cuda")
|
||||
depends_on("cuda@11.1:", when="@0.3+cuda")
|
||||
# https://github.com/google/jax/issues/12614
|
||||
depends_on("cuda@11.1:11.7.0", when="@0.1+cuda")
|
||||
with default_args(type="build"):
|
||||
# .bazelversion
|
||||
depends_on("bazel@6.1.2", when="@0.4.11:")
|
||||
depends_on("bazel@5.1.1", when="@0.3.7:0.4.10")
|
||||
depends_on("bazel@5.1.0", when="@0.3.5")
|
||||
depends_on("bazel@5.0.0", when="@0.3.0:0.3.2")
|
||||
depends_on("bazel@4.2.1", when="@0.1.75:0.1.76")
|
||||
depends_on("bazel@4.1.0", when="@0.1.70:0.1.74")
|
||||
|
||||
depends_on("cudnn@8.8:", when="@0.4.11:+cuda")
|
||||
depends_on("cudnn@8.2:", when="@0.4:0.4.7+cuda")
|
||||
depends_on("cudnn@8.0.5:", when="+cuda")
|
||||
# jaxlib/setup.py
|
||||
depends_on("py-setuptools")
|
||||
|
||||
# Historical dependencies
|
||||
depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
|
||||
depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run"))
|
||||
# build/build.py
|
||||
depends_on("py-build", when="@0.4.14:")
|
||||
|
||||
with default_args(type=("build", "run")):
|
||||
# Based on PyPI wheels
|
||||
depends_on("python@3.9:3.12", when="@0.4.17:")
|
||||
depends_on("python@3.9:3.11", when="@0.4.14:0.4.16")
|
||||
depends_on("python@3.8:3.11", when="@0.4.6:0.4.13")
|
||||
|
||||
# jaxlib/setup.py
|
||||
depends_on("py-scipy@1.9:", when="@0.4.19:")
|
||||
depends_on("py-scipy@1.7:", when="@0.4.7:")
|
||||
depends_on("py-scipy@1.5:")
|
||||
depends_on("py-numpy@1.22:", when="@0.4.14:")
|
||||
depends_on("py-numpy@1.21:", when="@0.4.7:")
|
||||
depends_on("py-numpy@1.20:", when="@0.3:")
|
||||
depends_on("py-numpy@1.18:")
|
||||
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
|
||||
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
|
||||
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
|
||||
|
||||
# Historical dependencies
|
||||
depends_on("py-absl-py", when="@:0.3")
|
||||
depends_on("py-flatbuffers@1.12:2", when="@0.1")
|
||||
|
||||
conflicts(
|
||||
"cuda_arch=none",
|
||||
|
Loading…
Reference in New Issue
Block a user