py-jax / JAX: add v0.4.31 (#45519)
This commit is contained in:
parent
cee266046b
commit
705d58005d
@ -24,6 +24,7 @@ class PyJax(PythonPackage):
|
||||
license("Apache-2.0")
|
||||
maintainers("adamjstewart", "jonas-eschle")
|
||||
|
||||
version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287")
|
||||
version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
|
||||
version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
|
||||
version("0.4.28", sha256="dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9")
|
||||
@ -57,25 +58,27 @@ class PyJax(PythonPackage):
|
||||
|
||||
with default_args(type=("build", "run")):
|
||||
# setup.py
|
||||
depends_on("python@3.10:", when="@0.4.31:")
|
||||
depends_on("python@3.9:", when="@0.4.14:")
|
||||
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
|
||||
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
|
||||
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
|
||||
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
|
||||
depends_on("py-numpy@1.24:", when="@0.4.31:")
|
||||
depends_on("py-numpy@1.22:", when="@0.4.14:")
|
||||
depends_on("py-numpy@1.21:", when="@0.4.7:")
|
||||
depends_on("py-numpy@1.20:", when="@0.3:")
|
||||
# https://github.com/google/jax/issues/19246
|
||||
depends_on("py-numpy@:1", when="@:0.4.25")
|
||||
depends_on("py-opt-einsum")
|
||||
depends_on("py-scipy@1.10:", when="@0.4.31:")
|
||||
depends_on("py-scipy@1.9:", when="@0.4.19:")
|
||||
depends_on("py-scipy@1.7:", when="@0.4.7:")
|
||||
depends_on("py-scipy@1.5:", when="@0.3:")
|
||||
depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9")
|
||||
|
||||
# jax/_src/lib/__init__.py
|
||||
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
|
||||
for v in [
|
||||
"0.4.31",
|
||||
"0.4.30",
|
||||
"0.4.29",
|
||||
"0.4.28",
|
||||
@ -108,6 +111,7 @@ class PyJax(PythonPackage):
|
||||
depends_on(f"py-jaxlib@:{v}", when=f"@{v}")
|
||||
|
||||
# See _minimum_jaxlib_version in jax/version.py
|
||||
depends_on("py-jaxlib@0.4.30:", when="@0.4.31:")
|
||||
depends_on("py-jaxlib@0.4.27:", when="@0.4.28:")
|
||||
depends_on("py-jaxlib@0.4.23:", when="@0.4.27:")
|
||||
depends_on("py-jaxlib@0.4.20:", when="@0.4.25:")
|
||||
@ -119,3 +123,7 @@ class PyJax(PythonPackage):
|
||||
depends_on("py-jaxlib@0.4.4:", when="@0.4.5:")
|
||||
depends_on("py-jaxlib@0.4.2:", when="@0.4.3:")
|
||||
depends_on("py-jaxlib@0.4.1:", when="@0.4.2:")
|
||||
|
||||
# Historical dependencies
|
||||
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
|
||||
depends_on("py-importlib-metadata@4.6:", when="@0.4.11:0.4.30 ^python@:3.9")
|
||||
|
@ -20,6 +20,7 @@ class PyJaxlib(PythonPackage, CudaPackage):
|
||||
license("Apache-2.0")
|
||||
maintainers("adamjstewart", "jonas-eschle")
|
||||
|
||||
version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94")
|
||||
version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
|
||||
version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
|
||||
version("0.4.28", sha256="4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b")
|
||||
@ -39,18 +40,19 @@ class PyJaxlib(PythonPackage, CudaPackage):
|
||||
version("0.4.4", sha256="881f402c7983b56b185e182d5315dd64c9f5320be96213d0415996ece1826806")
|
||||
version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
|
||||
|
||||
depends_on("c", type="build") # generated
|
||||
depends_on("cxx", type="build") # generated
|
||||
depends_on("c", type="build")
|
||||
depends_on("cxx", type="build")
|
||||
|
||||
variant("cuda", default=True, description="Build with CUDA enabled")
|
||||
variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda")
|
||||
|
||||
# docs/installation.md
|
||||
# docs/installation.md (Compatible with)
|
||||
with when("+cuda"):
|
||||
depends_on("cuda@12.1:", when="@0.4.26:")
|
||||
depends_on("cuda@11.8:", when="@0.4.11:")
|
||||
depends_on("cuda@11.4:", when="@0.4.0:0.4.7")
|
||||
depends_on("cudnn@9", when="@0.4.29:")
|
||||
depends_on("cudnn@9.1:9", when="@0.4.31:")
|
||||
depends_on("cudnn@9", when="@0.4.29:0.4.30")
|
||||
depends_on("cudnn@8.9:8", when="@0.4.26:0.4.28")
|
||||
depends_on("cudnn@8.8:8", when="@0.4.11:0.4.25")
|
||||
depends_on("cudnn@8.2:8", when="@0.4:0.4.7")
|
||||
@ -74,24 +76,29 @@ class PyJaxlib(PythonPackage, CudaPackage):
|
||||
|
||||
with default_args(type=("build", "run")):
|
||||
# Based on PyPI wheels
|
||||
depends_on("python@3.9:3.12", when="@0.4.17:")
|
||||
depends_on("python@3.10:3.12", when="@0.4.31:")
|
||||
depends_on("python@3.9:3.12", when="@0.4.17:0.4.30")
|
||||
depends_on("python@3.9:3.11", when="@0.4.14:0.4.16")
|
||||
depends_on("python@3.8:3.11", when="@0.4.6:0.4.13")
|
||||
|
||||
# jaxlib/setup.py
|
||||
depends_on("py-scipy@1.10:", when="@0.4.31:")
|
||||
depends_on("py-scipy@1.9:", when="@0.4.19:")
|
||||
depends_on("py-scipy@1.7:", when="@0.4.7:")
|
||||
depends_on("py-scipy@1.5:")
|
||||
depends_on("py-numpy@1.24:", when="@0.4.31:")
|
||||
depends_on("py-numpy@1.22:", when="@0.4.14:")
|
||||
depends_on("py-numpy@1.21:", when="@0.4.7:")
|
||||
depends_on("py-numpy@1.20:", when="@0.3:")
|
||||
# https://github.com/google/jax/issues/19246
|
||||
depends_on("py-numpy@:1", when="@:0.4.25")
|
||||
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
|
||||
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
|
||||
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
|
||||
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
|
||||
|
||||
# Historical dependencies
|
||||
# https://github.com/google/jax/issues/19246
|
||||
depends_on("py-numpy@:1", when="@:0.4.25")
|
||||
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
|
||||
|
||||
conflicts(
|
||||
"cuda_arch=none",
|
||||
when="+cuda",
|
||||
|
Loading…
Reference in New Issue
Block a user