py-jaxlib: add external ROCm support (#46467)

* add external ROCm support for py-jaxlib

* fix style

* remove fork releases
This commit is contained in:
afzpatel 2024-09-24 18:38:58 -04:00 committed by GitHub
parent 022eca1cfe
commit a474034023
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,8 +7,25 @@
from spack.package import *
rocm_dependencies = [
"hsa-rocr-dev",
"hip",
"rccl",
"rocprim",
"hipcub",
"rocthrust",
"roctracer-dev",
"rocrand",
"hipsparse",
"hipfft",
"rocfft",
"rocblas",
"miopen-hip",
"rocminfo",
]
class PyJaxlib(PythonPackage, CudaPackage):
class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
"""XLA library for Jax"""
homepage = "https://github.com/google/jax"
@ -62,6 +79,12 @@ class PyJaxlib(PythonPackage, CudaPackage):
depends_on("nccl@2.16:", when="@0.4.18:")
depends_on("nccl")
with when("+rocm"):
for pkg_dep in rocm_dependencies:
depends_on(f"{pkg_dep}@6:", when="@0.4.28:")
depends_on(pkg_dep)
depends_on("py-nanobind")
with default_args(type="build"):
# .bazelversion
depends_on("bazel@6.5.0", when="@0.4.28:")
@ -161,6 +184,10 @@ def install(self, spec, prefix):
"--bazel_startup_options="
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
)
if "+rocm" in spec:
args.append("--enable_rocm")
args.append("--rocm_path={0}".format(self.spec["hip"].prefix))
python(*args)
with working_dir(self.wrapped_package_object.tmp_path):
args = std_pip_args + ["--prefix=" + self.prefix, "."]