py-torch: fix +rocm+nccl build (#32771)

This commit is contained in:
Adam J. Stewart 2022-09-29 04:01:32 -05:00 committed by GitHub
parent 77afad229c
commit bc039524da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -10,7 +10,7 @@
from spack.package import *
class PyTorch(PythonPackage, CudaPackage):
class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
"""Tensors and Dynamic neural networks in Python
with strong GPU acceleration."""
@@ -100,6 +100,7 @@ class PyTorch(PythonPackage, CudaPackage):
)
conflicts("+cuda+rocm")
conflicts("+tensorpipe", when="+rocm", msg="TensorPipe doesn't yet support ROCm")
conflicts("+breakpad", when="target=ppc64:")
conflicts("+breakpad", when="target=ppc64le:")
@@ -177,14 +178,14 @@ class PyTorch(PythonPackage, CudaPackage):
depends_on("cudnn@7:", when="@1.6:+cudnn")
depends_on("magma+cuda", when="+magma+cuda")
depends_on("magma+rocm", when="+magma+rocm")
depends_on("nccl", when="+nccl")
depends_on("nccl", when="+nccl+cuda")
depends_on("numactl", when="+numa")
depends_on("llvm-openmp", when="%apple-clang +openmp")
depends_on("valgrind", when="+valgrind")
with when("+rocm"):
depends_on("hsa-rocr-dev")
depends_on("hip")
depends_on("rccl")
depends_on("rccl", when="+nccl")
depends_on("rocprim")
depends_on("hipcub")
depends_on("rocthrust")
@@ -423,6 +424,7 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False):
enable_or_disable("rocm")
if "+rocm" in self.spec:
env.set("PYTORCH_ROCM_ARCH", ";".join(self.spec.variants["amdgpu_target"].value))
env.set("HSA_PATH", self.spec["hsa-rocr-dev"].prefix)
env.set("ROCBLAS_PATH", self.spec["rocblas"].prefix)
env.set("ROCFFT_PATH", self.spec["rocfft"].prefix)
@@ -432,7 +434,8 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False):
env.set("HIPRAND_PATH", self.spec["rocrand"].prefix)
env.set("ROCRAND_PATH", self.spec["rocrand"].prefix)
env.set("MIOPEN_PATH", self.spec["miopen-hip"].prefix)
env.set("RCCL_PATH", self.spec["rccl"].prefix)
if "+nccl" in self.spec:
env.set("RCCL_PATH", self.spec["rccl"].prefix)
env.set("ROCPRIM_PATH", self.spec["rocprim"].prefix)
env.set("HIPCUB_PATH", self.spec["hipcub"].prefix)
env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix)
@@ -454,7 +457,7 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False):
enable_or_disable("breakpad")
enable_or_disable("nccl")
if "+nccl" in self.spec:
if "+cuda+nccl" in self.spec:
env.set("NCCL_LIB_DIR", self.spec["nccl"].libs.directories[0])
env.set("NCCL_INCLUDE_DIR", self.spec["nccl"].prefix.include)