py-jaxlib: add external ROCm support (#46467)
* add external ROCm support for py-jaxlib * fix style * remove fork releases
This commit is contained in:
parent
022eca1cfe
commit
a474034023
@ -7,8 +7,25 @@
|
|||||||
|
|
||||||
from spack.package import *
|
from spack.package import *
|
||||||
|
|
||||||
|
rocm_dependencies = [
|
||||||
|
"hsa-rocr-dev",
|
||||||
|
"hip",
|
||||||
|
"rccl",
|
||||||
|
"rocprim",
|
||||||
|
"hipcub",
|
||||||
|
"rocthrust",
|
||||||
|
"roctracer-dev",
|
||||||
|
"rocrand",
|
||||||
|
"hipsparse",
|
||||||
|
"hipfft",
|
||||||
|
"rocfft",
|
||||||
|
"rocblas",
|
||||||
|
"miopen-hip",
|
||||||
|
"rocminfo",
|
||||||
|
]
|
||||||
|
|
||||||
class PyJaxlib(PythonPackage, CudaPackage):
|
|
||||||
|
class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
|
||||||
"""XLA library for Jax"""
|
"""XLA library for Jax"""
|
||||||
|
|
||||||
homepage = "https://github.com/google/jax"
|
homepage = "https://github.com/google/jax"
|
||||||
@ -62,6 +79,12 @@ class PyJaxlib(PythonPackage, CudaPackage):
|
|||||||
depends_on("nccl@2.16:", when="@0.4.18:")
|
depends_on("nccl@2.16:", when="@0.4.18:")
|
||||||
depends_on("nccl")
|
depends_on("nccl")
|
||||||
|
|
||||||
|
with when("+rocm"):
|
||||||
|
for pkg_dep in rocm_dependencies:
|
||||||
|
depends_on(f"{pkg_dep}@6:", when="@0.4.28:")
|
||||||
|
depends_on(pkg_dep)
|
||||||
|
depends_on("py-nanobind")
|
||||||
|
|
||||||
with default_args(type="build"):
|
with default_args(type="build"):
|
||||||
# .bazelversion
|
# .bazelversion
|
||||||
depends_on("bazel@6.5.0", when="@0.4.28:")
|
depends_on("bazel@6.5.0", when="@0.4.28:")
|
||||||
@ -161,6 +184,10 @@ def install(self, spec, prefix):
|
|||||||
"--bazel_startup_options="
|
"--bazel_startup_options="
|
||||||
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
|
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
|
||||||
)
|
)
|
||||||
|
if "+rocm" in spec:
|
||||||
|
args.append("--enable_rocm")
|
||||||
|
args.append("--rocm_path={0}".format(self.spec["hip"].prefix))
|
||||||
|
|
||||||
python(*args)
|
python(*args)
|
||||||
with working_dir(self.wrapped_package_object.tmp_path):
|
with working_dir(self.wrapped_package_object.tmp_path):
|
||||||
args = std_pip_args + ["--prefix=" + self.prefix, "."]
|
args = std_pip_args + ["--prefix=" + self.prefix, "."]
|
||||||
|
Loading…
Reference in New Issue
Block a user