cray-mpich: adding partial GTL support (#45830)

cray-mpich now has a rocm variant. You can use gtl_lib in the
flag_handler like so:

```python
    def flag_handler(self, name, flags):
        wrapper_flags = []
        environment_flags = []
        build_system_flags = []

        if self.spec.satisfies("+rocm"):
            if self.spec.satisfies("^cray-mpich"):
                gtl_lib = self.spec["cray-mpich"].package.gtl_lib
                build_system_flags.extend(gtl_lib.get(name) or [])
            # hipcc is not wrapped, we need to pass the flags via the
            # build system.
            build_system_flags.extend(flags)

        return (wrapper_flags, environment_flags, build_system_flags)
```

---------

Co-authored-by: Richard Berger <rberger@lanl.gov>
Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com>
Co-authored-by: Richard Berger <richard.berger@outlook.com>
This commit is contained in:
etiennemlb 2025-02-12 21:01:40 +01:00 committed by GitHub
parent 6f1dce95f9
commit ae50757f3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,20 +4,24 @@
import os
import llnl.util.tty as tty
from spack.package import *
from spack.pkg.builtin.mpich import MpichEnvironmentModifications
from spack.util.module_cmd import get_path_args_from_module_line, module
class CrayMpich(MpichEnvironmentModifications, Package):
class CrayMpich(MpichEnvironmentModifications, Package, CudaPackage, ROCmPackage):
"""Cray's MPICH is a high performance and widely portable implementation of
the Message Passing Interface (MPI) standard."""
homepage = "https://docs.nersc.gov/development/compilers/wrappers/"
has_code = False # Skip attempts to fetch source that is not available
maintainers("haampie")
maintainers("etiennemlb", "haampie")
version("8.1.30")
version("8.1.28")
version("8.1.25")
version("8.1.24")
version("8.1.21")
@ -123,3 +127,86 @@ def libs(self):
libs += find_libraries(libraries, root=self.prefix.lib64, recursive=True)
return libs
@property
def gtl_lib(self):
# GPU transport Layer (GTL) handling background:
# - The cray-mpich module defines an environment variable per supported
# GPU (say, PE_MPICH_GTL_LIBS_amd_gfx942). So we should read the
# appropriate variable.
# In practice loading a module and checking its content is a PITA. We
# simplify by assuming that the GTL for a given vendor (say, AMD), is
# one and the same for all the targets of this vendor (one GTL for all
# Nvidia or one GTL for all AMD devices).
# - Second, except if you have a very weird mpich layout, the GTL are
# located in /opt/cray/pe/mpich/<cray_mpich_version>/gtl/lib when the
# MPI libraries are in
# /opt/cray/pe/mpich/<cray_mpich_version>/ofi/<vendor>/<vendor_version>.
# Example:
# /opt/cray/pe/mpich/8.1.28/gtl/lib
# /opt/cray/pe/mpich/8.1.28/ofi/<vendor>/<vendor_version>
# /opt/cray/pe/mpich/8.1.28/ofi/<vendor>/<vendor_version>/../../../gtl/lib
gtl_kinds = {
"cuda": {
"lib": "libmpi_gtl_cuda",
"variant": "cuda_arch",
"values": {"70", "80", "90"},
},
"rocm": {
"lib": "libmpi_gtl_hsa",
"variant": "amdgpu_target",
"values": {"gfx906", "gfx908", "gfx90a", "gfx940", "gfx942"},
},
}
for variant, gtl_kind in gtl_kinds.items():
arch_variant = gtl_kind["variant"]
arch_values = gtl_kind["values"]
gtl_lib = gtl_kind["lib"]
if self.spec.satisfies(f"+{variant} {arch_variant}=*"):
accelerator_architecture_set = set(self.spec.variants[arch_variant].value)
if len(
accelerator_architecture_set
) >= 1 and not accelerator_architecture_set.issubset(arch_values):
raise InstallError(
f"cray-mpich variant '+{variant} {arch_variant}'"
" was specified but no GTL support could be found for it."
)
mpi_root = os.path.abspath(
os.path.join(self.prefix, os.pardir, os.pardir, os.pardir)
)
gtl_root = os.path.join(mpi_root, "gtl", "lib")
gtl_shared_libraries = find_libraries(
[gtl_lib], root=gtl_root, shared=True, recursive=False
)
if len(gtl_shared_libraries) != 1:
raise InstallError(
f"cray-mpich variant '+{variant} {arch_variant}'"
" was specified and GTL support was found for it but"
f" the '{gtl_lib}' could not be correctly found on disk."
)
gtl_library_fullpath = list(gtl_shared_libraries)[0]
tty.debug(f"Selected GTL: {gtl_library_fullpath}")
gtl_library_directory = os.path.dirname(gtl_library_fullpath)
gtl_library_name = os.path.splitext(
os.path.basename(gtl_library_fullpath).split("lib")[1]
)[0]
# Early break. Only one GTL can be active at a given time.
return {
"ldflags": [
f"-L{gtl_library_directory}",
f"-Wl,-rpath,{gtl_library_directory}",
],
"ldlibs": [f"-l{gtl_library_name}"],
}
return {}