ucc: add ROCm and rccl support (#46580)

This commit is contained in:
afzpatel 2024-12-02 14:43:53 -05:00 committed by GitHub
parent 5ddbb1566d
commit 8e5a040985
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,7 +5,7 @@
from spack.package import *
class Ucc(AutotoolsPackage, CudaPackage):
class Ucc(AutotoolsPackage, CudaPackage, ROCmPackage):
"""UCC is a collective communication operations API and library that is
flexible, complete, and feature-rich for current and emerging programming
models and runtimes."""
@ -23,8 +23,7 @@ class Ucc(AutotoolsPackage, CudaPackage):
variant("cuda", default=False, description="Enable CUDA TL")
variant("nccl", default=False, description="Enable NCCL TL", when="+cuda")
# RCCL build not tested
# variant("rccl", default=False, description="Enable RCCL TL")
variant("rccl", default=False, description="Enable RCCL TL", when="+rocm")
# https://github.com/openucx/ucc/pull/847
patch(
@ -40,7 +39,7 @@ class Ucc(AutotoolsPackage, CudaPackage):
depends_on("ucx")
depends_on("nccl", when="+nccl")
# depends_on("rccl", when="+rccl")
depends_on("rccl", when="+rccl")
with when("+nccl"):
for arch in CudaPackage.cuda_arch_values:
@ -55,5 +54,26 @@ def configure_args(self):
args = []
args.extend(self.with_or_without("cuda", activation_value="prefix"))
args.extend(self.with_or_without("nccl", activation_value="prefix"))
# args.extend(self.with_or_without("rccl", activation_value="prefix"))
if self.spec.satisfies("+rocm"):
cppflags = " ".join(
"-I" + include_dir
for include_dir in (
self.spec["hip"].prefix.include,
self.spec["hip"].prefix.include.hip,
self.spec["hsa-rocr-dev"].prefix.include.hsa,
)
)
ldflags = " ".join(
"-L" + library_dir
for library_dir in (
self.spec["hip"].prefix.lib,
self.spec["hsa-rocr-dev"].prefix.lib,
)
)
args.extend(["CPPFLAGS=" + cppflags, "LDFLAGS=" + ldflags])
args.append("--with-rocm=" + self.spec["hip"].prefix)
args.append("--with-ucx=" + self.spec["ucx"].prefix)
args.extend(self.with_or_without("rccl", activation_value="prefix"))
else:
args.append("--without-rocm")
return args