strumpack: Propagate cuda_arch to slate (#50460)

This commit is contained in:
Hugh Carson 2025-05-20 10:51:09 -04:00 committed by GitHub
parent b77d9b87f8
commit 3f00eeabd2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -102,12 +102,12 @@ class Strumpack(CMakePackage, CudaPackage, ROCmPackage):
depends_on("slate", when="+slate")
depends_on("magma+cuda", when="+magma+cuda")
depends_on("magma+rocm", when="+magma+rocm")
depends_on("slate+cuda", when="+cuda+slate")
depends_on("slate+rocm", when="+rocm+slate")
for val in ROCmPackage.amdgpu_targets:
depends_on(
"slate amdgpu_target={0}".format(val), when="+slate amdgpu_target={0}".format(val)
)
with when("+slate+cuda"):
for val in CudaPackage.cuda_arch_values:
depends_on(f"slate +cuda cuda_arch={val}", when=f"cuda_arch={val}")
with when("+slate+rocm"):
for val in ROCmPackage.amdgpu_targets:
depends_on(f"slate +rocm amdgpu_target={val}", when=f"amdgpu_target={val}")
conflicts("+parmetis", when="~mpi")
conflicts("+butterflypack", when="~mpi")
@ -166,24 +166,24 @@ def cmake_args(self):
if "+cuda" in spec:
args.extend(
[
"-DCUDA_TOOLKIT_ROOT_DIR={0}".format(spec["cuda"].prefix),
"-DCMAKE_CUDA_HOST_COMPILER={0}".format(env["SPACK_CXX"]),
f"-DCUDA_TOOLKIT_ROOT_DIR={spec['cuda'].prefix}",
f"-DCMAKE_CUDA_HOST_COMPILER={env['SPACK_CXX']}",
]
)
cuda_archs = spec.variants["cuda_arch"].value
if "none" not in cuda_archs:
args.append("-DCUDA_NVCC_FLAGS={0}".format(" ".join(self.cuda_flags(cuda_archs))))
args.append(f"-DCUDA_NVCC_FLAGS={' '.join(self.cuda_flags(cuda_archs))}")
if "+rocm" in spec:
args.append("-DCMAKE_CXX_COMPILER={0}".format(spec["hip"].hipcc))
args.append("-DHIP_ROOT_DIR={0}".format(spec["hip"].prefix))
args.append(f"-DCMAKE_CXX_COMPILER={spec['hip'].hipcc}")
args.append(f"-DHIP_ROOT_DIR={spec['hip'].prefix}")
rocm_archs = spec.variants["amdgpu_target"].value
hipcc_flags = []
if spec.satisfies("@7.0.1: +rocm"):
hipcc_flags.append("-std=c++14")
if "none" not in rocm_archs:
hipcc_flags.append("--amdgpu-target={0}".format(",".join(rocm_archs)))
args.append("-DHIP_HIPCC_FLAGS={0}".format(" ".join(hipcc_flags)))
hipcc_flags.append(f"--amdgpu-target={','.join(rocm_archs)}")
args.append(f"-DHIP_HIPCC_FLAGS={' '.join(hipcc_flags)}")
if "%cce" in spec:
# Assume the proper Cray CCE module (cce) is loaded:
@ -220,10 +220,8 @@ def _test_example(self, test_prog, test_cmd, pre_args=[]):
mkfile.write("cmake_minimum_required(VERSION 3.15)\n")
mkfile.write("project(StrumpackSmokeTest LANGUAGES CXX)\n")
mkfile.write("find_package(STRUMPACK REQUIRED)\n")
mkfile.write("add_executable({0} {0}.cpp)\n".format(test_prog))
mkfile.write(
"target_link_libraries({0} ".format(test_prog) + "PRIVATE STRUMPACK::strumpack)\n"
)
mkfile.write(f"add_executable({test_prog} {test_prog}.cpp)\n")
mkfile.write(f"target_link_libraries({test_prog} PRIVATE STRUMPACK::strumpack)\n")
with working_dir(test_dir):
opts = self.std_cmake_args + self.cmake_args() + ["."]