py-jax: add v0.4.3 (#35460)

* py-jax: add v0.4.3

* Minimum version is minimum

* py-jax no longer has cuda variant

* Enable CUDA by default

* Link to discussion of upper bound
This commit is contained in:
Adam J. Stewart 2023-02-21 12:14:27 -07:00 committed by GitHub
parent cddef35ef8
commit 16adda3db9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 64 additions and 60 deletions

View File

@ -37,12 +37,6 @@ class PyAlphafold(PythonPackage, CudaPackage):
depends_on("py-immutabledict@2.0.0:", type=("build", "run")) depends_on("py-immutabledict@2.0.0:", type=("build", "run"))
depends_on("py-jax@0.2.14:", type=("build", "run"), when="@2.1.1") depends_on("py-jax@0.2.14:", type=("build", "run"), when="@2.1.1")
depends_on("py-jax@0.3.17:", type=("build", "run"), when="@2.2.4") depends_on("py-jax@0.3.17:", type=("build", "run"), when="@2.2.4")
for arch in CudaPackage.cuda_arch_values:
depends_on(
"py-jax+cuda cuda_arch={0}".format(arch),
type=("build", "run"),
when="cuda_arch={0}".format(arch),
)
depends_on("py-ml-collections@0.1.0:", type=("build", "run")) depends_on("py-ml-collections@0.1.0:", type=("build", "run"))
depends_on("py-numpy@1.19.5:", type=("build", "run"), when="@2.1.1") depends_on("py-numpy@1.19.5:", type=("build", "run"), when="@2.1.1")
depends_on("py-numpy@1.21.6:", type=("build", "run"), when="@2.2.4") depends_on("py-numpy@1.21.6:", type=("build", "run"), when="@2.2.4")

View File

@ -7,7 +7,7 @@
from spack.package import * from spack.package import *
class PyJax(PythonPackage, CudaPackage): class PyJax(PythonPackage):
"""JAX is Autograd and XLA, brought together for high-performance """JAX is Autograd and XLA, brought together for high-performance
machine learning research. With its updated version of Autograd, machine learning research. With its updated version of Autograd,
JAX can automatically differentiate native Python and NumPy JAX can automatically differentiate native Python and NumPy
@ -21,29 +21,29 @@ class PyJax(PythonPackage, CudaPackage):
homepage = "https://github.com/google/jax" homepage = "https://github.com/google/jax"
pypi = "jax/jax-0.2.25.tar.gz" pypi = "jax/jax-0.2.25.tar.gz"
version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae")
version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd") version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd")
version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446") version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446")
variant("cuda", default=True, description="CUDA support") depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
depends_on("python@3.7:", type=("build", "run"))
depends_on("py-setuptools", type="build") depends_on("py-setuptools", type="build")
depends_on("py-numpy@1.18:", type=("build", "run"), when="@0.2.25") depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.23")
depends_on("py-numpy@1.18:", type=("build", "run")) depends_on("py-numpy@1.18:", type=("build", "run"))
depends_on("py-absl-py", type=("build", "run"))
depends_on("py-opt-einsum", type=("build", "run")) depends_on("py-opt-einsum", type=("build", "run"))
depends_on("py-scipy@1.2.1:", type=("build", "run"), when="@0.2.25") depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run"))
depends_on("py-scipy@1.5:", type=("build", "run"), when="@0.3.23") depends_on("py-scipy@1.2.1:", type=("build", "run"))
depends_on("py-typing-extensions", type=("build", "run"))
depends_on("py-etils+epath", type=("build", "run"), when="@0.3.23") # See _minimum_jaxlib_version in jax/version.py
depends_on("py-jaxlib@0.3.15:", type=("build", "run"), when="@0.3.23~cuda") jax_to_jaxlib = {
depends_on("py-jaxlib@0.3.15:+cuda", type=("build", "run"), when="@0.3.23+cuda") "0.4.3": "0.4.2",
depends_on("py-jaxlib@0.1.69:", type=("build", "run"), when="@0.2.25~cuda") "0.3.23": "0.3.15",
depends_on("py-jaxlib@0.1.69:+cuda", type=("build", "run"), when="@0.2.25+cuda") "0.2.25": "0.1.69",
for arch in CudaPackage.cuda_arch_values: }
depends_on(
"py-jaxlib+cuda cuda_arch={0}".format(arch), for jax, jaxlib in jax_to_jaxlib.items():
type=("build", "run"), depends_on(f"py-jaxlib@{jaxlib}:", when=f"@{jax}", type=("build", "run"))
when="cuda_arch={0}".format(arch),
) # Historical dependencies
depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
depends_on("py-typing-extensions", when="@:0.3", type=("build", "run"))
depends_on("py-etils+epath", when="@0.3", type=("build", "run"))

View File

@ -17,25 +17,55 @@ class PyJaxlib(PythonPackage, CudaPackage):
tmp_path = "" tmp_path = ""
buildtmp = "" buildtmp = ""
version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d") version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d")
version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35") version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35")
# see jaxlib/setup.py for dependencies variant("cuda", default=True, description="Build with CUDA")
depends_on("python@3.7:", type=("build", "run"))
depends_on("py-setuptools", type="build")
depends_on("py-numpy@1.18:", type=("build", "run"), when="@0.1.74") # jaxlib/setup.py
depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.22") depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
depends_on("py-setuptools", type="build")
depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
depends_on("py-numpy@1.18:", type=("build", "run"))
depends_on("py-scipy@1.5:", type=("build", "run")) depends_on("py-scipy@1.5:", type=("build", "run"))
depends_on("py-absl-py", type=("build", "run"))
depends_on("py-flatbuffers@1.12:2", type=("build", "run"), when="@0.1.74") # .bazelversion
# Bazel 5 not yet supported: https://github.com/google/jax/issues/8440 depends_on("bazel@5.1.1:", when="@0.3:", type="build")
depends_on("bazel@4.1.0:4", type=("build"), when="@0.1.74") # https://github.com/google/jax/issues/8440
# Bazel 5 support starts here depends_on("bazel@4.1:4", when="@0.1", type="build")
depends_on("bazel@5.1.1:", type=("build"), when="@0.3.22")
# README.md
depends_on("cuda@11.4:", when="@0.4:+cuda")
depends_on("cuda@11.1:", when="@0.3+cuda")
# https://github.com/google/jax/issues/12614
depends_on("cuda@11.1:11.7.0", when="@0.1+cuda")
depends_on("cudnn@8.2:", when="@0.4:+cuda")
depends_on("cudnn@8.0.5:", when="+cuda") depends_on("cudnn@8.0.5:", when="+cuda")
depends_on("cuda@11.1:11.7.0", when="@0.1.74+cuda")
depends_on("cuda@11.1:", when="@0.3.22+cuda") # Historical dependencies
depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run"))
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")
self.buildtmp = tempfile.mkdtemp(prefix="spack")
# triple quotes necessary because of a variety
# of other embedded quote(s)
filter_file(
"""f"--output_path={output_path}",""",
"""f"--output_path={output_path}","""
"""f"--sources_path=%s","""
"""f"--nohome_rc'","""
"""f"--nosystem_rc'",""" % self.tmp_path,
"build/build.py",
)
filter_file(
"args = parser.parse_args()",
"args,junk = parser.parse_known_args()",
"build/build_wheel.py",
string=True,
)
def install(self, spec, prefix): def install(self, spec, prefix):
args = [] args = []
@ -58,23 +88,3 @@ def install(self, spec, prefix):
pip(*args) pip(*args)
remove_linked_tree(self.wrapped_package_object.tmp_path) remove_linked_tree(self.wrapped_package_object.tmp_path)
remove_linked_tree(self.wrapped_package_object.buildtmp) remove_linked_tree(self.wrapped_package_object.buildtmp)
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")
self.buildtmp = tempfile.mkdtemp(prefix="spack")
# triple quotes necessary because of a variety
# of other embedded quote(s)
filter_file(
"""f"--output_path={output_path}",""",
"""f"--output_path={output_path}","""
"""f"--sources_path=%s","""
"""f"--nohome_rc'","""
"""f"--nosystem_rc'",""" % self.tmp_path,
"build/build.py",
)
filter_file(
"args = parser.parse_args()",
"args,junk = parser.parse_known_args()",
"build/build_wheel.py",
string=True,
)