From a474034023b177abcc9519dc14fd8da19dc09f39 Mon Sep 17 00:00:00 2001 From: afzpatel <122491982+afzpatel@users.noreply.github.com> Date: Tue, 24 Sep 2024 18:38:58 -0400 Subject: [PATCH] py-jaxlib: add external ROCm support (#46467) * add external ROCm support for py-jaxlib * fix style * remove fork releases --- .../builtin/packages/py-jaxlib/package.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index fcd624cdd23..09eb522c56a 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -7,8 +7,25 @@ from spack.package import * +rocm_dependencies = [ + "hsa-rocr-dev", + "hip", + "rccl", + "rocprim", + "hipcub", + "rocthrust", + "roctracer-dev", + "rocrand", + "hipsparse", + "hipfft", + "rocfft", + "rocblas", + "miopen-hip", + "rocminfo", +] -class PyJaxlib(PythonPackage, CudaPackage): + +class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): """XLA library for Jax""" homepage = "https://github.com/google/jax" @@ -62,6 +79,12 @@ class PyJaxlib(PythonPackage, CudaPackage): depends_on("nccl@2.16:", when="@0.4.18:") depends_on("nccl") + with when("+rocm"): + for pkg_dep in rocm_dependencies: + depends_on(f"{pkg_dep}@6:", when="@0.4.28:") + depends_on(pkg_dep) + depends_on("py-nanobind") + with default_args(type="build"): # .bazelversion depends_on("bazel@6.5.0", when="@0.4.28:") @@ -161,6 +184,10 @@ def install(self, spec, prefix): "--bazel_startup_options=" "--output_user_root={0}".format(self.wrapped_package_object.buildtmp) ) + if "+rocm" in spec: + args.append("--enable_rocm") + args.append("--rocm_path={0}".format(self.spec["hip"].prefix)) + python(*args) with working_dir(self.wrapped_package_object.tmp_path): args = std_pip_args + ["--prefix=" + self.prefix, "."]