JAX: add v0.4.32+ (#46346)

* JAX: add v0.4.34

* Disable search for clang

* Update CUDA flags

* Add py-jax 0.4.33, comment out until py-jaxlib 0.4.33 is also released

* Fix GCC build

* Try TF_NVCC_CLANG

* py-jax: add v0.4.34

* jax no longer has separate tags for jaxlib

* Install compiled wheel

* Join path before glob

* Wheel is in spack stage, not tmp path

* Add 0.4.35

* Add newer versions

* Build system has been refactored yet again

* Drop clang

* Fix build with source tarball, rocm support

* Support GCC

* Remove clang-specific compiler flags

* enable_cuda flag was removed

* Fix logic

* py-jax: add v0.4.38

* Add patch to fix GCC support

* Patch no longer needed

* Skip patching, directly pass flags

* New flags

* Remove unused import

* Patch changed

* Use older version of patch

* Newer patch

* Add CUDA symlink

* Symlink more directories

* Recursive symlink

* Import function

* Recursive search

* Undo cuda changes

* Add v0.5.0

* I quit
This commit is contained in:
Adam J. Stewart 2025-01-28 13:37:50 +01:00 committed by GitHub
parent 0d3667175a
commit 7a82c703c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 122 additions and 74 deletions

View File

@ -7,22 +7,26 @@
class PyJax(PythonPackage):
"""JAX is Autograd and XLA, brought together for high-performance
machine learning research. With its updated version of Autograd,
JAX can automatically differentiate native Python and NumPy
functions. It can differentiate through loops, branches,
recursion, and closures, and it can take derivatives of
derivatives of derivatives. It supports reverse-mode
differentiation (a.k.a. backpropagation) via grad as well as
forward-mode differentiation, and the two can be composed
arbitrarily to any order."""
"""Differentiate, compile, and transform Numpy code.
homepage = "https://github.com/google/jax"
JAX is a Python library for accelerator-oriented array computation and program transformation,
designed for high-performance numerical computing and large-scale machine learning.
"""
homepage = "https://github.com/jax-ml/jax"
pypi = "jax/jax-0.4.27.tar.gz"
license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle")
# version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8")
# version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
# version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
# version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
# version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
# version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
# version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
# version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287")
version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
@ -59,9 +63,11 @@ class PyJax(PythonPackage):
# setup.py
depends_on("python@3.10:", when="@0.4.31:")
depends_on("python@3.9:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.4:", when="@0.4.29,0.4.35:")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
depends_on("py-numpy@1.25:", when="@0.5:")
depends_on("py-numpy@1.24:", when="@0.4.31:")
depends_on("py-numpy@1.22:", when="@0.4.14:")
depends_on("py-numpy@1.21:", when="@0.4.7:")
@ -69,6 +75,7 @@ class PyJax(PythonPackage):
# https://github.com/google/jax/issues/19246
depends_on("py-numpy@:1", when="@:0.4.25")
depends_on("py-opt-einsum")
depends_on("py-scipy@1.11.1:", when="@0.5:")
depends_on("py-scipy@1.10:", when="@0.4.31:")
depends_on("py-scipy@1.9:", when="@0.4.19:")
depends_on("py-scipy@1.7:", when="@0.4.7:")
@ -77,6 +84,14 @@ class PyJax(PythonPackage):
# jax/_src/lib/__init__.py
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
for v in [
# "0.5.0",
# "0.4.38",
# "0.4.37",
# "0.4.36",
# "0.4.35",
# "0.4.34",
# "0.4.33",
# "0.4.32",
"0.4.31",
"0.4.30",
"0.4.29",
@ -110,6 +125,13 @@ class PyJax(PythonPackage):
depends_on(f"py-jaxlib@:{v}", when=f"@{v}")
# See _minimum_jaxlib_version in jax/version.py
# depends_on("py-jaxlib@0.5:", when="@0.5:")
# depends_on("py-jaxlib@0.4.38:", when="@0.4.38:")
# depends_on("py-jaxlib@0.4.36:", when="@0.4.36:")
# depends_on("py-jaxlib@0.4.35:", when="@0.4.35:")
# depends_on("py-jaxlib@0.4.34:", when="@0.4.34:")
# depends_on("py-jaxlib@0.4.33:", when="@0.4.33:")
# depends_on("py-jaxlib@0.4.32:", when="@0.4.32:")
depends_on("py-jaxlib@0.4.30:", when="@0.4.31:")
depends_on("py-jaxlib@0.4.27:", when="@0.4.28:")
depends_on("py-jaxlib@0.4.23:", when="@0.4.27:")
@ -124,5 +146,4 @@ class PyJax(PythonPackage):
depends_on("py-jaxlib@0.4.1:", when="@0.4.2:")
# Historical dependencies
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
depends_on("py-importlib-metadata@4.6:", when="@0.4.11:0.4.30 ^python@:3.9")

View File

@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import tempfile
import glob
from spack.build_systems.python import PythonPipBuilder
from spack.package import *
@ -26,17 +26,27 @@
class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
"""XLA library for Jax"""
"""XLA library for Jax.
homepage = "https://github.com/google/jax"
url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.27.tar.gz"
jaxlib is the support library for JAX. While JAX itself is a pure Python package,
jaxlib contains the binary (C/C++) parts of the library, including Python bindings,
the XLA compiler, the PJRT runtime, and a handful of handwritten kernels.
"""
tmp_path = ""
buildtmp = ""
homepage = "https://github.com/jax-ml/jax"
url = "https://github.com/jax-ml/jax/archive/refs/tags/jax-v0.4.34.tar.gz"
license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle")
# version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9")
# version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
# version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
# version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
# version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
# version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
# version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
# version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94")
version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
@ -57,12 +67,12 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
version("0.4.4", sha256="881f402c7983b56b185e182d5315dd64c9f5320be96213d0415996ece1826806")
version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
depends_on("c", type="build")
depends_on("cxx", type="build")
variant("cuda", default=True, description="Build with CUDA enabled")
variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda")
depends_on("c", type="build")
depends_on("cxx", type="build")
# docs/installation.md (Compatible with)
with when("+cuda"):
depends_on("cuda@12.1:", when="@0.4.26:")
@ -98,32 +108,42 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
depends_on("py-build", when="@0.4.14:")
with default_args(type=("build", "run")):
# Based on PyPI wheels
depends_on("python@3.10:", when="@0.4.31:")
depends_on("python@3.9:", when="@0.4.14:")
depends_on("python@3.8:", when="@0.4.6:")
# Based on PyPI wheels
depends_on("python@:3.13")
depends_on("python@:3.12", when="@:0.4.33")
depends_on("python@:3.11", when="@:0.4.16")
# jaxlib/setup.py
depends_on("py-scipy@1.11.1:", when="@0.5:")
depends_on("py-scipy@1.10:", when="@0.4.31:")
depends_on("py-scipy@1.9:", when="@0.4.19:")
depends_on("py-scipy@1.7:", when="@0.4.7:")
depends_on("py-scipy@1.5:")
depends_on("py-numpy@1.25:", when="@0.5:")
depends_on("py-numpy@1.24:", when="@0.4.31:")
depends_on("py-numpy@1.22:", when="@0.4.14:")
depends_on("py-numpy@1.21:", when="@0.4.7:")
depends_on("py-numpy@1.20:", when="@0.3:")
# https://github.com/google/jax/issues/19246
depends_on("py-numpy@:1", when="@:0.4.25")
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
# Historical dependencies
# https://github.com/google/jax/issues/19246
depends_on("py-numpy@:1", when="@:0.4.25")
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
patch(
"https://github.com/jax-ml/jax/commit/f62af6457a6cc575a7b1ada08d541f0dd0eb5765.patch?full_index=1",
sha256="d3b7ea2cfeba927e40a11f07e4cbf80939f7fe69448c9eb55231a93bd64e5c02",
when="@0.4.36:0.4.38",
)
patch(
"https://github.com/jax-ml/jax/pull/25473.patch?full_index=1",
sha256="9d6977bc32046600bf8b15863251283fe7546896340367a7f14e3dccf418b4fe",
when="@0.4.36:0.4.37",
)
patch(
"https://github.com/google/jax/pull/20101.patch?full_index=1",
sha256="4dfb9f32d4eeb0a0fb3a6f4124c4170e3fe49511f1b768cd634c78d489962275",
@ -144,58 +164,65 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
# https://github.com/google/jax/issues/19992
conflicts("@0.4.4:", when="target=ppc64le:")
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")
self.buildtmp = tempfile.mkdtemp(prefix="spack")
filter_file(
"build --spawn_strategy=standalone",
f"""
# Limit CPU workers to spack jobs instead of using all HOST_CPUS.
build --spawn_strategy=standalone
build --local_cpu_resources={make_jobs}
""".strip(),
".bazelrc",
string=True,
)
filter_file(
'f"--output_path={output_path}",',
'f"--output_path={output_path}",'
f' "--sources_path={self.tmp_path}",'
' "--nohome_rc",'
' "--nosystem_rc",'
f' "--jobs={make_jobs}",',
"build/build.py",
string=True,
)
build_wheel = join_path("build", "build_wheel.py")
if self.spec.satisfies("@0.4.14:"):
build_wheel = join_path("jaxlib", "tools", "build_wheel.py")
filter_file(
"args = parser.parse_args()",
"args, junk = parser.parse_known_args()",
build_wheel,
string=True,
)
def url_for_version(self, version):
url = "https://github.com/jax-ml/jax/archive/refs/tags/{}-v{}.tar.gz"
if version >= Version("0.4.33"):
name = "jax"
else:
name = "jaxlib"
return url.format(name, version)
def install(self, spec, prefix):
args = []
args.append("build/build.py")
# https://jax.readthedocs.io/en/latest/developer.html
args = ["build/build.py"]
if spec.satisfies("@0.4.36:"):
args.append("build")
if spec.satisfies("+cuda"):
args.append("--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt")
elif spec.satisfies("+rocm"):
args.append("--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt")
else:
args.append("--wheels=jaxlib")
if spec.satisfies("@0.4.32:"):
if spec.satisfies("%clang"):
args.append("--use_clang=true")
else:
args.append("--use_clang=false")
if "+cuda" in spec:
args.append("--enable_cuda")
args.append("--cuda_path={0}".format(self.spec["cuda"].prefix))
args.append("--cudnn_path={0}".format(self.spec["cudnn"].prefix))
capabilities = CudaPackage.compute_capabilities(spec.variants["cuda_arch"].value)
args.append("--cuda_compute_capabilities={0}".format(",".join(capabilities)))
args.append(
"--bazel_startup_options="
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
args.append(f"--cuda_compute_capabilities={','.join(capabilities)}")
if spec.satisfies("@:0.4.35"):
args.append("--enable_cuda")
if spec.satisfies("@0.4.32:"):
args.extend(
[
f"--bazel_options=--repo_env=LOCAL_CUDA_PATH={spec['cuda'].prefix}",
f"--bazel_options=--repo_env=LOCAL_CUDNN_PATH={spec['cudnn'].prefix}",
]
)
else:
args.extend(
[f"--cuda_path={spec['cuda'].prefix}", f"--cudnn_path={spec['cudnn'].prefix}"]
)
if "+nccl" in spec and spec.satisfies("@0.4.32:"):
args.append(f"--bazel_options=--repo_env=LOCAL_NCCL_PATH={spec['nccl'].prefix}")
if "+rocm" in spec:
args.append("--enable_rocm")
args.append("--rocm_path={0}".format(self.spec["hip"].prefix))
args.extend(["--enable_rocm", f"--rocm_path={self.spec['hip'].prefix}"])
args.extend(
[
f"--bazel_options=--jobs={make_jobs}",
"--bazel_startup_options=--nohome_rc",
"--bazel_startup_options=--nosystem_rc",
]
)
python(*args)
with working_dir(self.wrapped_package_object.tmp_path):
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", ".")
remove_linked_tree(self.wrapped_package_object.tmp_path)
remove_linked_tree(self.wrapped_package_object.buildtmp)
whl = glob.glob(join_path("dist", "*.whl"))[0]
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)