py-jaxlib: add spack-built ROCm support (#49611)

* py-jaxlib: add spack-built ROCm support

* fix style

* py-jaxlib 0.4.38 rocm support

* py-jaxlib 0.4.38 rocm support

* add comgr dependency

* changes for ROCm external and enable till 0.4.38

* enable version of py-jax

* add jax+rocm to ci

* add conflict for cuda and remove py-jaxlib from aarch64 pipeline

* Update var/spack/repos/builtin/packages/py-jaxlib/package.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* add conflict for aarch64

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Afzal Patel 2025-03-26 11:23:52 -04:00 committed by GitHub
parent 145b0667cc
commit 2cd773aea4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 95 additions and 42 deletions

View File

@ -27,9 +27,8 @@ spack:
- py-transformers
# JAX
# Does not yet support Spack-installed ROCm
# - py-jax
# - py-jaxlib
- py-jax
- py-jaxlib
# Keras
- py-keras backend=tensorflow

View File

@ -20,13 +20,13 @@ class PyJax(PythonPackage):
maintainers("adamjstewart", "jonas-eschle")
# version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8")
# version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
# version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
# version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
# version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
# version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
# version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
# version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287")
version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
@ -85,13 +85,13 @@ class PyJax(PythonPackage):
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
for v in [
# "0.5.0",
# "0.4.38",
# "0.4.37",
# "0.4.36",
# "0.4.35",
# "0.4.34",
# "0.4.33",
# "0.4.32",
"0.4.38",
"0.4.37",
"0.4.36",
"0.4.35",
"0.4.34",
"0.4.33",
"0.4.32",
"0.4.31",
"0.4.30",
"0.4.29",
@ -126,12 +126,12 @@ class PyJax(PythonPackage):
# See _minimum_jaxlib_version in jax/version.py
# depends_on("py-jaxlib@0.5:", when="@0.5:")
# depends_on("py-jaxlib@0.4.38:", when="@0.4.38:")
# depends_on("py-jaxlib@0.4.36:", when="@0.4.36:")
# depends_on("py-jaxlib@0.4.35:", when="@0.4.35:")
# depends_on("py-jaxlib@0.4.34:", when="@0.4.34:")
# depends_on("py-jaxlib@0.4.33:", when="@0.4.33:")
# depends_on("py-jaxlib@0.4.32:", when="@0.4.32:")
depends_on("py-jaxlib@0.4.38:", when="@0.4.38:")
depends_on("py-jaxlib@0.4.36:", when="@0.4.36:")
depends_on("py-jaxlib@0.4.35:", when="@0.4.35:")
depends_on("py-jaxlib@0.4.34:", when="@0.4.34:")
depends_on("py-jaxlib@0.4.33:", when="@0.4.33:")
depends_on("py-jaxlib@0.4.32:", when="@0.4.32:")
depends_on("py-jaxlib@0.4.30:", when="@0.4.31:")
depends_on("py-jaxlib@0.4.27:", when="@0.4.28:")
depends_on("py-jaxlib@0.4.23:", when="@0.4.27:")

View File

@ -8,20 +8,27 @@
from spack.package import *
rocm_dependencies = [
"hsa-rocr-dev",
"comgr",
"hip",
"rccl",
"rocprim",
"hipblas",
"hipblaslt",
"hipcub",
"rocthrust",
"roctracer-dev",
"rocrand",
"hipsparse",
"hipfft",
"rocfft",
"rocblas",
"hiprand",
"hipsolver",
"hipsparse",
"hsa-rocr-dev",
"miopen-hip",
"rccl",
"rocblas",
"rocfft",
"rocminfo",
"rocprim",
"rocrand",
"rocsolver",
"rocsparse",
"roctracer-dev",
"rocm-core",
]
@ -39,14 +46,17 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle")
# version("0.5.3", sha256="1094581a30ec069965f4e3e67d60262570cc3dd016adc62073bc24347b14270c")
# version("0.5.2", sha256="8e9de1e012dd65fc4a9eec8af4aa2bf6782767130a5d8e1c1e342b7d658280fe")
# version("0.5.1", sha256="e74b1209517682075933f757d646b73040d09fe39ee3e9e4cd398407dd0902d2")
# version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9")
# version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
# version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
# version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
# version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
# version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
# version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
# version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94")
version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
@ -93,6 +103,10 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
for pkg_dep in rocm_dependencies:
depends_on(f"{pkg_dep}@6:", when="@0.4.28:")
depends_on(pkg_dep)
depends_on("rocprofiler-register", when="^hip@6.2:")
depends_on("hipblas-common", when="^hip@6.3:")
depends_on("hsakmt-roct", when="^hip@:6.2")
depends_on("llvm-amdgpu")
depends_on("py-nanobind")
with default_args(type="build"):
@ -113,6 +127,7 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
depends_on("python@3.9:", when="@0.4.14:")
depends_on("python@3.8:", when="@0.4.6:")
depends_on("python@:3.13")
depends_on("python@:3.12", when="+rocm")
depends_on("python@:3.12", when="@:0.4.33")
depends_on("python@:3.11", when="@:0.4.16")
@ -167,6 +182,22 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
# Fails to build with freshly released CUDA (#48708).
conflicts("^cuda@12.8:", when="@:0.4.31")
# external CUDA is not supported https://github.com/jax-ml/jax/issues/23689
conflicts("+cuda", when="@0.4.32:")
# aarch64 is not supported https://github.com/jax-ml/jax/issues/25598
conflicts("target=aarch64:", when="@0.4.32:")
resource(
name="xla",
url="https://github.com/ROCm/xla/archive/07543ab117699a57c1267b453a62f89b1d5953fd.tar.gz",
sha256="cee377479654201c61cc3f230d89603cd589525fea2faf44564a23c70ba1448d",
expand=True,
destination="",
placement="xla",
when="@0.4.38:0.5.2 +rocm",
)
def url_for_version(self, version):
url = "https://github.com/jax-ml/jax/archive/refs/tags/{}-v{}.tar.gz"
if version >= Version("0.4.33"):
@ -175,6 +206,20 @@ def url_for_version(self, version):
name = "jaxlib"
return url.format(name, version)
def setup_build_environment(self, env):
spec = self.spec
if spec.satisfies("@0.4.38: +rocm") and not spec["hip"].external:
if spec.satisfies("^hip@6.2:"):
rocm_dependencies.append("rocprofiler-register")
if spec.satisfies("^hip@6.3:"):
rocm_dependencies.append("hipblas-common")
else:
rocm_dependencies.append("hsakmt-roct")
env.set("LLVM_PATH", spec["llvm-amdgpu"].prefix)
for pkg_dep in rocm_dependencies:
env.prepend_path("TF_ROCM_MULTIPLE_PATHS", spec[pkg_dep].prefix)
env.prune_duplicate_paths("TF_ROCM_MULTIPLE_PATHS")
def install(self, spec, prefix):
# https://jax.readthedocs.io/en/latest/developer.html
args = ["build/build.py"]
@ -216,7 +261,15 @@ def install(self, spec, prefix):
args.append(f"--bazel_options=--repo_env=LOCAL_NCCL_PATH={spec['nccl'].prefix}")
if "+rocm" in spec:
args.extend(["--enable_rocm", f"--rocm_path={self.spec['hip'].prefix}"])
args.append(f"--rocm_path={self.spec['hip'].prefix}")
if spec.satisfies("@:0.4.35"):
args.append("--enable_rocm")
if spec.satisfies("@0.4.38:") and not spec["hip"].external:
args.append("--bazel_options=--@local_config_rocm//rocm:rocm_path_type=multiple")
if spec.satisfies("@0.4.38:0.5.2"):
args.append(
f"--bazel_options=--override_repository=xla={self.stage.source_path}/xla"
)
args.extend(
[
@ -227,5 +280,6 @@ def install(self, spec, prefix):
)
python(*args)
whl = glob.glob(join_path("dist", "*.whl"))[0]
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)
for whl in glob.glob(join_path("dist", "*.whl")):
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)