py-jaxlib: add external ROCm support (#46467)
* add external ROCm support for py-jaxlib * fix style * remove fork releases
This commit is contained in:
parent
022eca1cfe
commit
a474034023
@ -7,8 +7,25 @@
|
||||
|
||||
from spack.package import *
|
||||
|
||||
rocm_dependencies = [
|
||||
"hsa-rocr-dev",
|
||||
"hip",
|
||||
"rccl",
|
||||
"rocprim",
|
||||
"hipcub",
|
||||
"rocthrust",
|
||||
"roctracer-dev",
|
||||
"rocrand",
|
||||
"hipsparse",
|
||||
"hipfft",
|
||||
"rocfft",
|
||||
"rocblas",
|
||||
"miopen-hip",
|
||||
"rocminfo",
|
||||
]
|
||||
|
||||
class PyJaxlib(PythonPackage, CudaPackage):
|
||||
|
||||
class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||
"""XLA library for Jax"""
|
||||
|
||||
homepage = "https://github.com/google/jax"
|
||||
@ -62,6 +79,12 @@ class PyJaxlib(PythonPackage, CudaPackage):
|
||||
depends_on("nccl@2.16:", when="@0.4.18:")
|
||||
depends_on("nccl")
|
||||
|
||||
with when("+rocm"):
|
||||
for pkg_dep in rocm_dependencies:
|
||||
depends_on(f"{pkg_dep}@6:", when="@0.4.28:")
|
||||
depends_on(pkg_dep)
|
||||
depends_on("py-nanobind")
|
||||
|
||||
with default_args(type="build"):
|
||||
# .bazelversion
|
||||
depends_on("bazel@6.5.0", when="@0.4.28:")
|
||||
@ -161,6 +184,10 @@ def install(self, spec, prefix):
|
||||
"--bazel_startup_options="
|
||||
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
|
||||
)
|
||||
if "+rocm" in spec:
|
||||
args.append("--enable_rocm")
|
||||
args.append("--rocm_path={0}".format(self.spec["hip"].prefix))
|
||||
|
||||
python(*args)
|
||||
with working_dir(self.wrapped_package_object.tmp_path):
|
||||
args = std_pip_args + ["--prefix=" + self.prefix, "."]
|
||||
|
Loading…
Reference in New Issue
Block a user