JAX: add v0.4.32+ (#46346)
* JAX: add v0.4.34 * Disable search for clang * Update CUDA flags * Add py-jax 0.4.33, comment out until py-jaxlib 0.4.33 is also released * Fix GCC build * Try TF_NVCC_CLANG * py-jax: add v0.4.34 * jax no longer has separate tags for jaxlib * Install compiled wheel * Join path before glob * Wheel is in spack stage, not tmp path * Add 0.4.35 * Add newer versions * Build system has been refactored yet again * Drop clang * Fix build with source tarball, rocm support * Support GCC * Remove clang-specific compiler flags * enable_cuda flag was removed * Fix logic * py-jax: add v0.4.38 * Add patch to fix GCC support * Patch no longer needed * Skip patching, directly pass flags * New flags * Remove unused import * Patch changed * Use older version of patch * Newer patch * Add CUDA symlink * Symlink more directories * Recursive symlink * Import function * Recursive search * Undo cuda changes * Add v0.5.0 * I quit
This commit is contained in:
parent
0d3667175a
commit
7a82c703c7
@ -7,22 +7,26 @@
|
||||
|
||||
|
||||
class PyJax(PythonPackage):
|
||||
"""JAX is Autograd and XLA, brought together for high-performance
|
||||
machine learning research. With its updated version of Autograd,
|
||||
JAX can automatically differentiate native Python and NumPy
|
||||
functions. It can differentiate through loops, branches,
|
||||
recursion, and closures, and it can take derivatives of
|
||||
derivatives of derivatives. It supports reverse-mode
|
||||
differentiation (a.k.a. backpropagation) via grad as well as
|
||||
forward-mode differentiation, and the two can be composed
|
||||
arbitrarily to any order."""
|
||||
"""Differentiate, compile, and transform Numpy code.
|
||||
|
||||
homepage = "https://github.com/google/jax"
|
||||
JAX is a Python library for accelerator-oriented array computation and program transformation,
|
||||
designed for high-performance numerical computing and large-scale machine learning.
|
||||
"""
|
||||
|
||||
homepage = "https://github.com/jax-ml/jax"
|
||||
pypi = "jax/jax-0.4.27.tar.gz"
|
||||
|
||||
license("Apache-2.0")
|
||||
maintainers("adamjstewart", "jonas-eschle")
|
||||
|
||||
# version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8")
|
||||
# version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
|
||||
# version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
|
||||
# version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
|
||||
# version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
|
||||
# version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
|
||||
# version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
|
||||
# version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
|
||||
version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287")
|
||||
version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
|
||||
version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
|
||||
@ -59,9 +63,11 @@ class PyJax(PythonPackage):
|
||||
# setup.py
|
||||
depends_on("python@3.10:", when="@0.4.31:")
|
||||
depends_on("python@3.9:", when="@0.4.14:")
|
||||
depends_on("py-ml-dtypes@0.4:", when="@0.4.29,0.4.35:")
|
||||
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
|
||||
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
|
||||
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
|
||||
depends_on("py-numpy@1.25:", when="@0.5:")
|
||||
depends_on("py-numpy@1.24:", when="@0.4.31:")
|
||||
depends_on("py-numpy@1.22:", when="@0.4.14:")
|
||||
depends_on("py-numpy@1.21:", when="@0.4.7:")
|
||||
@ -69,6 +75,7 @@ class PyJax(PythonPackage):
|
||||
# https://github.com/google/jax/issues/19246
|
||||
depends_on("py-numpy@:1", when="@:0.4.25")
|
||||
depends_on("py-opt-einsum")
|
||||
depends_on("py-scipy@1.11.1:", when="@0.5:")
|
||||
depends_on("py-scipy@1.10:", when="@0.4.31:")
|
||||
depends_on("py-scipy@1.9:", when="@0.4.19:")
|
||||
depends_on("py-scipy@1.7:", when="@0.4.7:")
|
||||
@ -77,6 +84,14 @@ class PyJax(PythonPackage):
|
||||
# jax/_src/lib/__init__.py
|
||||
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
|
||||
for v in [
|
||||
# "0.5.0",
|
||||
# "0.4.38",
|
||||
# "0.4.37",
|
||||
# "0.4.36",
|
||||
# "0.4.35",
|
||||
# "0.4.34",
|
||||
# "0.4.33",
|
||||
# "0.4.32",
|
||||
"0.4.31",
|
||||
"0.4.30",
|
||||
"0.4.29",
|
||||
@ -110,6 +125,13 @@ class PyJax(PythonPackage):
|
||||
depends_on(f"py-jaxlib@:{v}", when=f"@{v}")
|
||||
|
||||
# See _minimum_jaxlib_version in jax/version.py
|
||||
# depends_on("py-jaxlib@0.5:", when="@0.5:")
|
||||
# depends_on("py-jaxlib@0.4.38:", when="@0.4.38:")
|
||||
# depends_on("py-jaxlib@0.4.36:", when="@0.4.36:")
|
||||
# depends_on("py-jaxlib@0.4.35:", when="@0.4.35:")
|
||||
# depends_on("py-jaxlib@0.4.34:", when="@0.4.34:")
|
||||
# depends_on("py-jaxlib@0.4.33:", when="@0.4.33:")
|
||||
# depends_on("py-jaxlib@0.4.32:", when="@0.4.32:")
|
||||
depends_on("py-jaxlib@0.4.30:", when="@0.4.31:")
|
||||
depends_on("py-jaxlib@0.4.27:", when="@0.4.28:")
|
||||
depends_on("py-jaxlib@0.4.23:", when="@0.4.27:")
|
||||
@ -124,5 +146,4 @@ class PyJax(PythonPackage):
|
||||
depends_on("py-jaxlib@0.4.1:", when="@0.4.2:")
|
||||
|
||||
# Historical dependencies
|
||||
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
|
||||
depends_on("py-importlib-metadata@4.6:", when="@0.4.11:0.4.30 ^python@:3.9")
|
||||
|
@ -2,7 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
|
||||
|
||||
import tempfile
|
||||
import glob
|
||||
|
||||
from spack.build_systems.python import PythonPipBuilder
|
||||
from spack.package import *
|
||||
@ -26,17 +26,27 @@
|
||||
|
||||
|
||||
class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||
"""XLA library for Jax"""
|
||||
"""XLA library for Jax.
|
||||
|
||||
homepage = "https://github.com/google/jax"
|
||||
url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.27.tar.gz"
|
||||
jaxlib is the support library for JAX. While JAX itself is a pure Python package,
|
||||
jaxlib contains the binary (C/C++) parts of the library, including Python bindings,
|
||||
the XLA compiler, the PJRT runtime, and a handful of handwritten kernels.
|
||||
"""
|
||||
|
||||
tmp_path = ""
|
||||
buildtmp = ""
|
||||
homepage = "https://github.com/jax-ml/jax"
|
||||
url = "https://github.com/jax-ml/jax/archive/refs/tags/jax-v0.4.34.tar.gz"
|
||||
|
||||
license("Apache-2.0")
|
||||
maintainers("adamjstewart", "jonas-eschle")
|
||||
|
||||
# version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9")
|
||||
# version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
|
||||
# version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
|
||||
# version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
|
||||
# version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
|
||||
# version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
|
||||
# version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
|
||||
# version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
|
||||
version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94")
|
||||
version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
|
||||
version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
|
||||
@ -57,12 +67,12 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||
version("0.4.4", sha256="881f402c7983b56b185e182d5315dd64c9f5320be96213d0415996ece1826806")
|
||||
version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
|
||||
|
||||
depends_on("c", type="build")
|
||||
depends_on("cxx", type="build")
|
||||
|
||||
variant("cuda", default=True, description="Build with CUDA enabled")
|
||||
variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda")
|
||||
|
||||
depends_on("c", type="build")
|
||||
depends_on("cxx", type="build")
|
||||
|
||||
# docs/installation.md (Compatible with)
|
||||
with when("+cuda"):
|
||||
depends_on("cuda@12.1:", when="@0.4.26:")
|
||||
@ -98,32 +108,42 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||
depends_on("py-build", when="@0.4.14:")
|
||||
|
||||
with default_args(type=("build", "run")):
|
||||
# Based on PyPI wheels
|
||||
depends_on("python@3.10:", when="@0.4.31:")
|
||||
depends_on("python@3.9:", when="@0.4.14:")
|
||||
depends_on("python@3.8:", when="@0.4.6:")
|
||||
# Based on PyPI wheels
|
||||
depends_on("python@:3.13")
|
||||
depends_on("python@:3.12", when="@:0.4.33")
|
||||
depends_on("python@:3.11", when="@:0.4.16")
|
||||
|
||||
# jaxlib/setup.py
|
||||
depends_on("py-scipy@1.11.1:", when="@0.5:")
|
||||
depends_on("py-scipy@1.10:", when="@0.4.31:")
|
||||
depends_on("py-scipy@1.9:", when="@0.4.19:")
|
||||
depends_on("py-scipy@1.7:", when="@0.4.7:")
|
||||
depends_on("py-scipy@1.5:")
|
||||
depends_on("py-numpy@1.25:", when="@0.5:")
|
||||
depends_on("py-numpy@1.24:", when="@0.4.31:")
|
||||
depends_on("py-numpy@1.22:", when="@0.4.14:")
|
||||
depends_on("py-numpy@1.21:", when="@0.4.7:")
|
||||
depends_on("py-numpy@1.20:", when="@0.3:")
|
||||
# https://github.com/google/jax/issues/19246
|
||||
depends_on("py-numpy@:1", when="@:0.4.25")
|
||||
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
|
||||
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
|
||||
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
|
||||
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
|
||||
|
||||
# Historical dependencies
|
||||
# https://github.com/google/jax/issues/19246
|
||||
depends_on("py-numpy@:1", when="@:0.4.25")
|
||||
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
|
||||
|
||||
patch(
|
||||
"https://github.com/jax-ml/jax/commit/f62af6457a6cc575a7b1ada08d541f0dd0eb5765.patch?full_index=1",
|
||||
sha256="d3b7ea2cfeba927e40a11f07e4cbf80939f7fe69448c9eb55231a93bd64e5c02",
|
||||
when="@0.4.36:0.4.38",
|
||||
)
|
||||
patch(
|
||||
"https://github.com/jax-ml/jax/pull/25473.patch?full_index=1",
|
||||
sha256="9d6977bc32046600bf8b15863251283fe7546896340367a7f14e3dccf418b4fe",
|
||||
when="@0.4.36:0.4.37",
|
||||
)
|
||||
patch(
|
||||
"https://github.com/google/jax/pull/20101.patch?full_index=1",
|
||||
sha256="4dfb9f32d4eeb0a0fb3a6f4124c4170e3fe49511f1b768cd634c78d489962275",
|
||||
@ -144,58 +164,65 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||
# https://github.com/google/jax/issues/19992
|
||||
conflicts("@0.4.4:", when="target=ppc64le:")
|
||||
|
||||
def patch(self):
|
||||
self.tmp_path = tempfile.mkdtemp(prefix="spack")
|
||||
self.buildtmp = tempfile.mkdtemp(prefix="spack")
|
||||
filter_file(
|
||||
"build --spawn_strategy=standalone",
|
||||
f"""
|
||||
# Limit CPU workers to spack jobs instead of using all HOST_CPUS.
|
||||
build --spawn_strategy=standalone
|
||||
build --local_cpu_resources={make_jobs}
|
||||
""".strip(),
|
||||
".bazelrc",
|
||||
string=True,
|
||||
)
|
||||
filter_file(
|
||||
'f"--output_path={output_path}",',
|
||||
'f"--output_path={output_path}",'
|
||||
f' "--sources_path={self.tmp_path}",'
|
||||
' "--nohome_rc",'
|
||||
' "--nosystem_rc",'
|
||||
f' "--jobs={make_jobs}",',
|
||||
"build/build.py",
|
||||
string=True,
|
||||
)
|
||||
build_wheel = join_path("build", "build_wheel.py")
|
||||
if self.spec.satisfies("@0.4.14:"):
|
||||
build_wheel = join_path("jaxlib", "tools", "build_wheel.py")
|
||||
filter_file(
|
||||
"args = parser.parse_args()",
|
||||
"args, junk = parser.parse_known_args()",
|
||||
build_wheel,
|
||||
string=True,
|
||||
)
|
||||
def url_for_version(self, version):
|
||||
url = "https://github.com/jax-ml/jax/archive/refs/tags/{}-v{}.tar.gz"
|
||||
if version >= Version("0.4.33"):
|
||||
name = "jax"
|
||||
else:
|
||||
name = "jaxlib"
|
||||
return url.format(name, version)
|
||||
|
||||
def install(self, spec, prefix):
|
||||
args = []
|
||||
args.append("build/build.py")
|
||||
# https://jax.readthedocs.io/en/latest/developer.html
|
||||
args = ["build/build.py"]
|
||||
|
||||
if spec.satisfies("@0.4.36:"):
|
||||
args.append("build")
|
||||
|
||||
if spec.satisfies("+cuda"):
|
||||
args.append("--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt")
|
||||
elif spec.satisfies("+rocm"):
|
||||
args.append("--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt")
|
||||
else:
|
||||
args.append("--wheels=jaxlib")
|
||||
|
||||
if spec.satisfies("@0.4.32:"):
|
||||
if spec.satisfies("%clang"):
|
||||
args.append("--use_clang=true")
|
||||
else:
|
||||
args.append("--use_clang=false")
|
||||
|
||||
if "+cuda" in spec:
|
||||
args.append("--enable_cuda")
|
||||
args.append("--cuda_path={0}".format(self.spec["cuda"].prefix))
|
||||
args.append("--cudnn_path={0}".format(self.spec["cudnn"].prefix))
|
||||
capabilities = CudaPackage.compute_capabilities(spec.variants["cuda_arch"].value)
|
||||
args.append("--cuda_compute_capabilities={0}".format(",".join(capabilities)))
|
||||
args.append(
|
||||
"--bazel_startup_options="
|
||||
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
|
||||
args.append(f"--cuda_compute_capabilities={','.join(capabilities)}")
|
||||
if spec.satisfies("@:0.4.35"):
|
||||
args.append("--enable_cuda")
|
||||
if spec.satisfies("@0.4.32:"):
|
||||
args.extend(
|
||||
[
|
||||
f"--bazel_options=--repo_env=LOCAL_CUDA_PATH={spec['cuda'].prefix}",
|
||||
f"--bazel_options=--repo_env=LOCAL_CUDNN_PATH={spec['cudnn'].prefix}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
args.extend(
|
||||
[f"--cuda_path={spec['cuda'].prefix}", f"--cudnn_path={spec['cudnn'].prefix}"]
|
||||
)
|
||||
|
||||
if "+nccl" in spec and spec.satisfies("@0.4.32:"):
|
||||
args.append(f"--bazel_options=--repo_env=LOCAL_NCCL_PATH={spec['nccl'].prefix}")
|
||||
|
||||
if "+rocm" in spec:
|
||||
args.append("--enable_rocm")
|
||||
args.append("--rocm_path={0}".format(self.spec["hip"].prefix))
|
||||
args.extend(["--enable_rocm", f"--rocm_path={self.spec['hip'].prefix}"])
|
||||
|
||||
args.extend(
|
||||
[
|
||||
f"--bazel_options=--jobs={make_jobs}",
|
||||
"--bazel_startup_options=--nohome_rc",
|
||||
"--bazel_startup_options=--nosystem_rc",
|
||||
]
|
||||
)
|
||||
|
||||
python(*args)
|
||||
with working_dir(self.wrapped_package_object.tmp_path):
|
||||
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", ".")
|
||||
remove_linked_tree(self.wrapped_package_object.tmp_path)
|
||||
remove_linked_tree(self.wrapped_package_object.buildtmp)
|
||||
whl = glob.glob(join_path("dist", "*.whl"))[0]
|
||||
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)
|
||||
|
Loading…
Reference in New Issue
Block a user