py-jaxlib: add spack-built ROCm support (#49611)
* py-jaxlib: add spack-built ROCm support * fix style * py-jaxlib 0.4.38 rocm support * py-jaxlib 0.4.38 rocm support * add comgr dependency * changes for ROCm external and enable till 0.4.38 * enable version of py-jax * add jax+rocm to ci * add conflict for cuda and remove py-jaxlib from aarch64 pipeline * Update var/spack/repos/builtin/packages/py-jaxlib/package.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * add conflict for aarch64 --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
parent
145b0667cc
commit
2cd773aea4
@ -27,9 +27,8 @@ spack:
|
||||
- py-transformers
|
||||
|
||||
# JAX
|
||||
# Does not yet support Spack-installed ROCm
|
||||
# - py-jax
|
||||
# - py-jaxlib
|
||||
- py-jax
|
||||
- py-jaxlib
|
||||
|
||||
# Keras
|
||||
- py-keras backend=tensorflow
|
||||
|
@ -20,13 +20,13 @@ class PyJax(PythonPackage):
|
||||
maintainers("adamjstewart", "jonas-eschle")
|
||||
|
||||
# version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8")
|
||||
# version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
|
||||
# version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
|
||||
# version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
|
||||
# version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
|
||||
# version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
|
||||
# version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
|
||||
# version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
|
||||
version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
|
||||
version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
|
||||
version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
|
||||
version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
|
||||
version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
|
||||
version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
|
||||
version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
|
||||
version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287")
|
||||
version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
|
||||
version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
|
||||
@ -85,13 +85,13 @@ class PyJax(PythonPackage):
|
||||
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
|
||||
for v in [
|
||||
# "0.5.0",
|
||||
# "0.4.38",
|
||||
# "0.4.37",
|
||||
# "0.4.36",
|
||||
# "0.4.35",
|
||||
# "0.4.34",
|
||||
# "0.4.33",
|
||||
# "0.4.32",
|
||||
"0.4.38",
|
||||
"0.4.37",
|
||||
"0.4.36",
|
||||
"0.4.35",
|
||||
"0.4.34",
|
||||
"0.4.33",
|
||||
"0.4.32",
|
||||
"0.4.31",
|
||||
"0.4.30",
|
||||
"0.4.29",
|
||||
@ -126,12 +126,12 @@ class PyJax(PythonPackage):
|
||||
|
||||
# See _minimum_jaxlib_version in jax/version.py
|
||||
# depends_on("py-jaxlib@0.5:", when="@0.5:")
|
||||
# depends_on("py-jaxlib@0.4.38:", when="@0.4.38:")
|
||||
# depends_on("py-jaxlib@0.4.36:", when="@0.4.36:")
|
||||
# depends_on("py-jaxlib@0.4.35:", when="@0.4.35:")
|
||||
# depends_on("py-jaxlib@0.4.34:", when="@0.4.34:")
|
||||
# depends_on("py-jaxlib@0.4.33:", when="@0.4.33:")
|
||||
# depends_on("py-jaxlib@0.4.32:", when="@0.4.32:")
|
||||
depends_on("py-jaxlib@0.4.38:", when="@0.4.38:")
|
||||
depends_on("py-jaxlib@0.4.36:", when="@0.4.36:")
|
||||
depends_on("py-jaxlib@0.4.35:", when="@0.4.35:")
|
||||
depends_on("py-jaxlib@0.4.34:", when="@0.4.34:")
|
||||
depends_on("py-jaxlib@0.4.33:", when="@0.4.33:")
|
||||
depends_on("py-jaxlib@0.4.32:", when="@0.4.32:")
|
||||
depends_on("py-jaxlib@0.4.30:", when="@0.4.31:")
|
||||
depends_on("py-jaxlib@0.4.27:", when="@0.4.28:")
|
||||
depends_on("py-jaxlib@0.4.23:", when="@0.4.27:")
|
||||
|
@ -8,20 +8,27 @@
|
||||
from spack.package import *
|
||||
|
||||
rocm_dependencies = [
|
||||
"hsa-rocr-dev",
|
||||
"comgr",
|
||||
"hip",
|
||||
"rccl",
|
||||
"rocprim",
|
||||
"hipblas",
|
||||
"hipblaslt",
|
||||
"hipcub",
|
||||
"rocthrust",
|
||||
"roctracer-dev",
|
||||
"rocrand",
|
||||
"hipsparse",
|
||||
"hipfft",
|
||||
"rocfft",
|
||||
"rocblas",
|
||||
"hiprand",
|
||||
"hipsolver",
|
||||
"hipsparse",
|
||||
"hsa-rocr-dev",
|
||||
"miopen-hip",
|
||||
"rccl",
|
||||
"rocblas",
|
||||
"rocfft",
|
||||
"rocminfo",
|
||||
"rocprim",
|
||||
"rocrand",
|
||||
"rocsolver",
|
||||
"rocsparse",
|
||||
"roctracer-dev",
|
||||
"rocm-core",
|
||||
]
|
||||
|
||||
|
||||
@ -39,14 +46,17 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||
license("Apache-2.0")
|
||||
maintainers("adamjstewart", "jonas-eschle")
|
||||
|
||||
# version("0.5.3", sha256="1094581a30ec069965f4e3e67d60262570cc3dd016adc62073bc24347b14270c")
|
||||
# version("0.5.2", sha256="8e9de1e012dd65fc4a9eec8af4aa2bf6782767130a5d8e1c1e342b7d658280fe")
|
||||
# version("0.5.1", sha256="e74b1209517682075933f757d646b73040d09fe39ee3e9e4cd398407dd0902d2")
|
||||
# version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9")
|
||||
# version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
|
||||
# version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
|
||||
# version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
|
||||
# version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
|
||||
# version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
|
||||
# version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
|
||||
# version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
|
||||
version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
|
||||
version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
|
||||
version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
|
||||
version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
|
||||
version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
|
||||
version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
|
||||
version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
|
||||
version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94")
|
||||
version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
|
||||
version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
|
||||
@ -93,6 +103,10 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||
for pkg_dep in rocm_dependencies:
|
||||
depends_on(f"{pkg_dep}@6:", when="@0.4.28:")
|
||||
depends_on(pkg_dep)
|
||||
depends_on("rocprofiler-register", when="^hip@6.2:")
|
||||
depends_on("hipblas-common", when="^hip@6.3:")
|
||||
depends_on("hsakmt-roct", when="^hip@:6.2")
|
||||
depends_on("llvm-amdgpu")
|
||||
depends_on("py-nanobind")
|
||||
|
||||
with default_args(type="build"):
|
||||
@ -113,6 +127,7 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||
depends_on("python@3.9:", when="@0.4.14:")
|
||||
depends_on("python@3.8:", when="@0.4.6:")
|
||||
depends_on("python@:3.13")
|
||||
depends_on("python@:3.12", when="+rocm")
|
||||
depends_on("python@:3.12", when="@:0.4.33")
|
||||
depends_on("python@:3.11", when="@:0.4.16")
|
||||
|
||||
@ -167,6 +182,22 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||
# Fails to build with freshly released CUDA (#48708).
|
||||
conflicts("^cuda@12.8:", when="@:0.4.31")
|
||||
|
||||
# external CUDA is not supported https://github.com/jax-ml/jax/issues/23689
|
||||
conflicts("+cuda", when="@0.4.32:")
|
||||
|
||||
# aarch64 is not supported https://github.com/jax-ml/jax/issues/25598
|
||||
conflicts("target=aarch64:", when="@0.4.32:")
|
||||
|
||||
resource(
|
||||
name="xla",
|
||||
url="https://github.com/ROCm/xla/archive/07543ab117699a57c1267b453a62f89b1d5953fd.tar.gz",
|
||||
sha256="cee377479654201c61cc3f230d89603cd589525fea2faf44564a23c70ba1448d",
|
||||
expand=True,
|
||||
destination="",
|
||||
placement="xla",
|
||||
when="@0.4.38:0.5.2 +rocm",
|
||||
)
|
||||
|
||||
def url_for_version(self, version):
|
||||
url = "https://github.com/jax-ml/jax/archive/refs/tags/{}-v{}.tar.gz"
|
||||
if version >= Version("0.4.33"):
|
||||
@ -175,6 +206,20 @@ def url_for_version(self, version):
|
||||
name = "jaxlib"
|
||||
return url.format(name, version)
|
||||
|
||||
def setup_build_environment(self, env):
|
||||
spec = self.spec
|
||||
if spec.satisfies("@0.4.38: +rocm") and not spec["hip"].external:
|
||||
if spec.satisfies("^hip@6.2:"):
|
||||
rocm_dependencies.append("rocprofiler-register")
|
||||
if spec.satisfies("^hip@6.3:"):
|
||||
rocm_dependencies.append("hipblas-common")
|
||||
else:
|
||||
rocm_dependencies.append("hsakmt-roct")
|
||||
env.set("LLVM_PATH", spec["llvm-amdgpu"].prefix)
|
||||
for pkg_dep in rocm_dependencies:
|
||||
env.prepend_path("TF_ROCM_MULTIPLE_PATHS", spec[pkg_dep].prefix)
|
||||
env.prune_duplicate_paths("TF_ROCM_MULTIPLE_PATHS")
|
||||
|
||||
def install(self, spec, prefix):
|
||||
# https://jax.readthedocs.io/en/latest/developer.html
|
||||
args = ["build/build.py"]
|
||||
@ -216,7 +261,15 @@ def install(self, spec, prefix):
|
||||
args.append(f"--bazel_options=--repo_env=LOCAL_NCCL_PATH={spec['nccl'].prefix}")
|
||||
|
||||
if "+rocm" in spec:
|
||||
args.extend(["--enable_rocm", f"--rocm_path={self.spec['hip'].prefix}"])
|
||||
args.append(f"--rocm_path={self.spec['hip'].prefix}")
|
||||
if spec.satisfies("@:0.4.35"):
|
||||
args.append("--enable_rocm")
|
||||
if spec.satisfies("@0.4.38:") and not spec["hip"].external:
|
||||
args.append("--bazel_options=--@local_config_rocm//rocm:rocm_path_type=multiple")
|
||||
if spec.satisfies("@0.4.38:0.5.2"):
|
||||
args.append(
|
||||
f"--bazel_options=--override_repository=xla={self.stage.source_path}/xla"
|
||||
)
|
||||
|
||||
args.extend(
|
||||
[
|
||||
@ -227,5 +280,6 @@ def install(self, spec, prefix):
|
||||
)
|
||||
|
||||
python(*args)
|
||||
whl = glob.glob(join_path("dist", "*.whl"))[0]
|
||||
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)
|
||||
|
||||
for whl in glob.glob(join_path("dist", "*.whl")):
|
||||
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)
|
||||
|
Loading…
Reference in New Issue
Block a user