JAX: add v0.4.29 (#44683)

Co-authored-by: adamjstewart <adamjstewart@users.noreply.github.com>
This commit is contained in:
Adam J. Stewart 2024-06-24 08:43:45 +02:00 committed by GitHub
parent c1f1e1396d
commit 910b923c5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 31 additions and 14 deletions

View File

@ -9,6 +9,14 @@
from spack.package import * from spack.package import *
_versions = { _versions = {
# cuDNN 9.2.0
"9.2.0.82-12": {
"Linux-x86_64": "1362b4d437e37e92c9814c3b4065db5106c2e03268e22275a5869e968cee7aa8",
"Linux-aarch64": "24cc2a0308dfe412c02c7d41d4b07ec12dacb021ebf8c719de38eb77d22f68c1",
},
"9.2.0.82-11": {
"Linux-x86_64": "99dcb3fa2bf7eed7f35b0f8e58e7d1f04d9a52e01e382efc1de16fed230d3b26"
},
# cuDNN 8.9.7 # cuDNN 8.9.7
"8.9.7.29-12": { "8.9.7.29-12": {
"Linux-x86_64": "475333625c7e42a7af3ca0b2f7506a106e30c93b1aa0081cd9c13efb6e21e3bb", "Linux-x86_64": "475333625c7e42a7af3ca0b2f7506a106e30c93b1aa0081cd9c13efb6e21e3bb",

View File

@ -24,6 +24,7 @@ class PyJax(PythonPackage):
license("Apache-2.0") license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle") maintainers("adamjstewart", "jonas-eschle")
version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
version("0.4.28", sha256="dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9") version("0.4.28", sha256="dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9")
version("0.4.27", sha256="f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c") version("0.4.27", sha256="f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c")
version("0.4.26", sha256="2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383") version("0.4.26", sha256="2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383")
@ -56,6 +57,7 @@ class PyJax(PythonPackage):
with default_args(type=("build", "run")): with default_args(type=("build", "run")):
# setup.py # setup.py
depends_on("python@3.9:", when="@0.4.14:") depends_on("python@3.9:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.4:", when="@0.4.29:")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
@ -71,6 +73,7 @@ class PyJax(PythonPackage):
# jax/_src/lib/__init__.py # jax/_src/lib/__init__.py
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4 # https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
for v in [ for v in [
"0.4.29",
"0.4.28", "0.4.28",
"0.4.27", "0.4.27",
"0.4.26", "0.4.26",

View File

@ -18,8 +18,9 @@ class PyJaxlib(PythonPackage, CudaPackage):
buildtmp = "" buildtmp = ""
license("Apache-2.0") license("Apache-2.0")
maintainers("adamjstewart") maintainers("adamjstewart", "jonas-eschle")
version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
version("0.4.28", sha256="4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b") version("0.4.28", sha256="4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b")
version("0.4.27", sha256="c2c82cd9ad3b395d5cbc0affa26a2938e52677a69ca8f0b9ef9922a52cac4f0c") version("0.4.27", sha256="c2c82cd9ad3b395d5cbc0affa26a2938e52677a69ca8f0b9ef9922a52cac4f0c")
version("0.4.26", sha256="ddc14da1eaa34f23430d40ad9b9585088575cac439a2fa1c6833a247e1b221fd") version("0.4.26", sha256="ddc14da1eaa34f23430d40ad9b9585088575cac439a2fa1c6833a247e1b221fd")
@ -46,9 +47,10 @@ class PyJaxlib(PythonPackage, CudaPackage):
depends_on("cuda@12.1:", when="@0.4.26:") depends_on("cuda@12.1:", when="@0.4.26:")
depends_on("cuda@11.8:", when="@0.4.11:") depends_on("cuda@11.8:", when="@0.4.11:")
depends_on("cuda@11.4:", when="@0.4.0:0.4.7") depends_on("cuda@11.4:", when="@0.4.0:0.4.7")
depends_on("cudnn@8.9:8", when="@0.4.26:") depends_on("cudnn@9", when="@0.4.29:")
depends_on("cudnn@8.8:", when="@0.4.11:") depends_on("cudnn@8.9:8", when="@0.4.26:0.4.28")
depends_on("cudnn@8.2:", when="@0.4:0.4.7") depends_on("cudnn@8.8:8", when="@0.4.11:0.4.25")
depends_on("cudnn@8.2:8", when="@0.4:0.4.7")
with when("+nccl"): with when("+nccl"):
depends_on("nccl@2.18:", when="@0.4.26:") depends_on("nccl@2.18:", when="@0.4.26:")
@ -80,6 +82,7 @@ class PyJaxlib(PythonPackage, CudaPackage):
depends_on("py-numpy@1.22:", when="@0.4.14:") depends_on("py-numpy@1.22:", when="@0.4.14:")
depends_on("py-numpy@1.21:", when="@0.4.7:") depends_on("py-numpy@1.21:", when="@0.4.7:")
depends_on("py-numpy@1.20:", when="@0.3:") depends_on("py-numpy@1.20:", when="@0.3:")
depends_on("py-ml-dtypes@0.4:", when="@0.4.29:")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")

View File

@ -17,11 +17,13 @@ class PyMlDtypes(PythonPackage):
license("Apache-2.0") license("Apache-2.0")
version("0.4.0", tag="v0.4.0", commit="9fc7e6773acb66fa496ed8d476a008a489a4da49")
version("0.3.1", tag="v0.3.1", commit="bbeedd470ecac727c42e97648c0f27bfc312af30") version("0.3.1", tag="v0.3.1", commit="bbeedd470ecac727c42e97648c0f27bfc312af30")
version("0.2.0", tag="v0.2.0", commit="5b9fc9ad978757654843f4a8d899715dbea30e88") version("0.2.0", tag="v0.2.0", commit="5b9fc9ad978757654843f4a8d899715dbea30e88")
depends_on("python@3.9:", when="@0.3:", type=("build", "link", "run")) depends_on("python@3.9:", when="@0.3:", type=("build", "link", "run"))
depends_on("py-numpy@1.21:", type=("build", "link", "run")) depends_on("py-numpy@1.21:", when="@0.4:", type=("build", "link", "run"))
depends_on("py-numpy@1.21:1", when="@:0.3", type=("build", "link", "run"))
# Build dependencies are overconstrained, older versions work just fine # Build dependencies are overconstrained, older versions work just fine
depends_on("py-pybind11", type=("build", "link")) depends_on("py-pybind11", when="@:0.3.1", type=("build", "link"))
depends_on("py-setuptools", type="build") depends_on("py-setuptools", type="build")

View File

@ -305,12 +305,12 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
depends_on("cuda@:11.4", when="@2.4:2.7") depends_on("cuda@:11.4", when="@2.4:2.7")
depends_on("cuda@:10.2", when="@:2.3") depends_on("cuda@:10.2", when="@:2.3")
depends_on("cudnn@8.9:", when="@2.15:") depends_on("cudnn@8.9:8", when="@2.15:")
depends_on("cudnn@8.7:", when="@2.14:") depends_on("cudnn@8.7:8", when="@2.14:")
depends_on("cudnn@8.6:", when="@2.12:") depends_on("cudnn@8.6:8", when="@2.12:")
depends_on("cudnn@8.1:", when="@2.5:") depends_on("cudnn@8.1:8", when="@2.5:")
depends_on("cudnn@8.0:", when="@2.4:") depends_on("cudnn@8.0:8", when="@2.4:")
depends_on("cudnn@7.6:", when="@2.1:") depends_on("cudnn@7.6:8", when="@2.1:")
depends_on("cudnn@:7", when="@:2.2") depends_on("cudnn@:7", when="@:2.2")
# depends_on('tensorrt', when='+tensorrt') # depends_on('tensorrt', when='+tensorrt')

View File

@ -247,8 +247,9 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
depends_on("cuda@9.2:11.4", when="@1.6:1.9+cuda") depends_on("cuda@9.2:11.4", when="@1.6:1.9+cuda")
depends_on("cuda@9:11.4", when="@:1.5+cuda") depends_on("cuda@9:11.4", when="@:1.5+cuda")
# https://github.com/pytorch/pytorch#prerequisites # https://github.com/pytorch/pytorch#prerequisites
depends_on("cudnn@8.5:", when="@2.3:+cudnn") # https://github.com/pytorch/pytorch/issues/119400
depends_on("cudnn@7:", when="@1.6:+cudnn") depends_on("cudnn@8.5:9.0", when="@2.3:+cudnn")
depends_on("cudnn@7:8", when="@1.6:2.2+cudnn")
depends_on("cudnn@7", when="@:1.5+cudnn") depends_on("cudnn@7", when="@:1.5+cudnn")
depends_on("magma+cuda", when="+magma+cuda") depends_on("magma+cuda", when="+magma+cuda")
depends_on("magma+rocm", when="+magma+rocm") depends_on("magma+rocm", when="+magma+rocm")