Check for LSF, Flux, and Slurm when determining MPI exec

Tara Drwenski 2025-03-19 14:16:41 -06:00
parent d199738f31
commit ba86a0eb48


@@ -2,17 +2,19 @@
 #
 # SPDX-License-Identifier: (Apache-2.0 OR MIT)

 import collections.abc
+import enum
 import os
 import re
-from typing import Tuple
+from typing import Optional, Tuple

 import llnl.util.filesystem as fs
 import llnl.util.tty as tty

 import spack.phase_callbacks
-import spack.spec
 import spack.util.prefix
 from spack.directives import depends_on
+from spack.spec import Spec
+from spack.util.executable import which_string

 from .cmake import CMakeBuilder, CMakePackage
@@ -48,6 +50,62 @@ def cmake_cache_filepath(name, value, comment=""):
     return 'set({0} "{1}" CACHE FILEPATH "{2}")\n'.format(name, value, comment)


+class Scheduler(enum.Enum):
+    LSF = enum.auto()
+    SLURM = enum.auto()
+    FLUX = enum.auto()
+
+
+def get_scheduler(spec: Spec) -> Optional[Scheduler]:
+    if spec.satisfies("^spectrum-mpi") or spec["mpi"].satisfies("schedulers=lsf"):
+        return Scheduler.LSF
+    slurm_checks = ["+slurm", "schedulers=slurm", "process_managers=slurm"]
+    if any(spec["mpi"].satisfies(variant) for variant in slurm_checks):
+        return Scheduler.SLURM
+    # TODO improve this when MPI implementations support flux
+    if which_string("flux") is not None:
+        return Scheduler.FLUX
+    return None
+
+
+def get_mpi_exec(spec: Spec) -> Optional[str]:
+    scheduler = get_scheduler(spec)
+    if scheduler == Scheduler.LSF:
+        return which_string("lrun")
+    elif scheduler == Scheduler.SLURM:
+        if spec["mpi"].external:
+            return which_string("srun")
+        else:
+            return os.path.join(spec["slurm"].prefix.bin, "srun")
+    elif scheduler == Scheduler.FLUX:
+        flux = which_string("flux")
+        return f"{flux};run" if flux else None
+    elif hasattr(spec["mpi"].package, "mpiexec"):
+        return spec["mpi"].package.mpiexec
+    else:
+        mpiexec = os.path.join(spec["mpi"].prefix.bin, "mpirun")
+        if not os.path.exists(mpiexec):
+            mpiexec = os.path.join(spec["mpi"].prefix.bin, "mpiexec")
+        return mpiexec
+
+
+def get_mpi_exec_num_proc(spec: Spec) -> str:
+    scheduler = get_scheduler(spec)
+    if scheduler in [Scheduler.FLUX, Scheduler.LSF, Scheduler.SLURM]:
+        return "-n"
+    else:
+        return "-np"
+
+
 class CachedCMakeBuilder(CMakeBuilder):
     #: Phases of a Cached CMake package
     #: Note: the initconfig phase is used for developer builds as a final phase to stop on
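The new helpers resolve the launcher in a fixed order: LSF (implied by spectrum-mpi or an MPI built with schedulers=lsf), then the Slurm variants, then a flux binary on PATH, and only then the MPI implementation's own launcher. For readers without a Spack checkout, here is a minimal standalone sketch of that precedence; the `mpi_hints` argument and `pick_launcher` name are illustrative, not part of the commit:

```python
import shutil
from typing import Optional, Set


def pick_launcher(mpi_hints: Set[str]) -> Optional[str]:
    """Sketch of the commit's precedence: LSF, then Slurm, then Flux, then MPI's own launcher."""
    if "spectrum-mpi" in mpi_hints or "schedulers=lsf" in mpi_hints:
        return shutil.which("lrun")  # LSF machines launch via lrun
    if mpi_hints & {"+slurm", "schedulers=slurm", "process_managers=slurm"}:
        return shutil.which("srun")  # Slurm launches via srun
    flux = shutil.which("flux")
    if flux is not None:
        return f"{flux};run"  # Flux jobs run as "flux run ..."
    # Otherwise fall back to the MPI implementation's generic launcher.
    return shutil.which("mpirun") or shutil.which("mpiexec")


print(pick_launcher({"schedulers=slurm"}))  # e.g. "/usr/bin/srun" on a Slurm machine
```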
@@ -199,27 +257,10 @@ def initconfig_mpi_entries(self):
         if hasattr(spec["mpi"], "mpifc"):
             entries.append(cmake_cache_path("MPI_Fortran_COMPILER", spec["mpi"].mpifc))

-        # Check for slurm
-        using_slurm = False
-        slurm_checks = ["+slurm", "schedulers=slurm", "process_managers=slurm"]
-        if any(spec["mpi"].satisfies(variant) for variant in slurm_checks):
-            using_slurm = True
-
         # Determine MPIEXEC
-        if using_slurm:
-            if spec["mpi"].external:
-                # Heuristic until we have dependents on externals
-                mpiexec = "/usr/bin/srun"
-            else:
-                mpiexec = os.path.join(spec["slurm"].prefix.bin, "srun")
-        elif hasattr(spec["mpi"].package, "mpiexec"):
-            mpiexec = spec["mpi"].package.mpiexec
-        else:
-            mpiexec = os.path.join(spec["mpi"].prefix.bin, "mpirun")
-            if not os.path.exists(mpiexec):
-                mpiexec = os.path.join(spec["mpi"].prefix.bin, "mpiexec")
+        mpiexec = get_mpi_exec(spec)

-        if not os.path.exists(mpiexec):
+        if mpiexec is None or not os.path.exists(mpiexec.split(";")[0]):
             msg = "Unable to determine MPIEXEC, %s tests may fail" % self.pkg.name
             entries.append("# {0}\n".format(msg))
             tty.warn(msg)
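The existence check now splits on ";" before calling os.path.exists because, in the Flux case, get_mpi_exec packs the launcher and its subcommand into one string, and only the part before the semicolon is a filesystem path. A small illustration with a hypothetical path:

```python
import os

mpiexec = "/usr/bin/flux;run"   # hypothetical value returned for a Flux system
binary = mpiexec.split(";")[0]  # "/usr/bin/flux" -- the only piece that exists on disk
print(os.path.exists(binary))   # checking os.path.exists(mpiexec) would always be False
```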
@@ -232,10 +273,7 @@ def initconfig_mpi_entries(self):
             entries.append(cmake_cache_path("MPIEXEC", mpiexec))

         # Determine MPIEXEC_NUMPROC_FLAG
-        if using_slurm:
-            entries.append(cmake_cache_string("MPIEXEC_NUMPROC_FLAG", "-n"))
-        else:
-            entries.append(cmake_cache_string("MPIEXEC_NUMPROC_FLAG", "-np"))
+        entries.append(cmake_cache_string("MPIEXEC_NUMPROC_FLAG", get_mpi_exec_num_proc(spec)))

         return entries
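Taken together, on a Slurm machine these entries would render in the generated initconfig roughly as a pair of set() commands for MPIEXEC and MPIEXEC_NUMPROC_FLAG. A sketch of the rendering, assuming cmake_cache_string is the STRING analogue of the cmake_cache_filepath helper shown earlier in this file:

```python
def cmake_cache_string(name, value, comment=""):
    # assumed to mirror the FILEPATH helper above, but with CACHE STRING
    return 'set({0} "{1}" CACHE STRING "{2}")\n'.format(name, value, comment)


print(cmake_cache_string("MPIEXEC_NUMPROC_FLAG", "-n"), end="")
# -> set(MPIEXEC_NUMPROC_FLAG "-n" CACHE STRING "")
```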
@@ -341,7 +379,7 @@ def initconfig_package_entries(self):
         return []

     def initconfig(
-        self, pkg: "CachedCMakePackage", spec: spack.spec.Spec, prefix: spack.util.prefix.Prefix
+        self, pkg: "CachedCMakePackage", spec: Spec, prefix: spack.util.prefix.Prefix
     ) -> None:
         cache_entries = (
             self.std_initconfig_entries()