py-torch: fix +rocm+nccl build (#32771)
This commit is contained in:
parent
77afad229c
commit
bc039524da
@ -10,7 +10,7 @@
|
||||
from spack.package import *
|
||||
|
||||
|
||||
class PyTorch(PythonPackage, CudaPackage):
|
||||
class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
|
||||
"""Tensors and Dynamic neural networks in Python
|
||||
with strong GPU acceleration."""
|
||||
|
||||
@ -100,6 +100,7 @@ class PyTorch(PythonPackage, CudaPackage):
|
||||
)
|
||||
|
||||
conflicts("+cuda+rocm")
|
||||
conflicts("+tensorpipe", when="+rocm", msg="TensorPipe doesn't yet support ROCm")
|
||||
conflicts("+breakpad", when="target=ppc64:")
|
||||
conflicts("+breakpad", when="target=ppc64le:")
|
||||
|
||||
@ -177,14 +178,14 @@ class PyTorch(PythonPackage, CudaPackage):
|
||||
depends_on("cudnn@7:", when="@1.6:+cudnn")
|
||||
depends_on("magma+cuda", when="+magma+cuda")
|
||||
depends_on("magma+rocm", when="+magma+rocm")
|
||||
depends_on("nccl", when="+nccl")
|
||||
depends_on("nccl", when="+nccl+cuda")
|
||||
depends_on("numactl", when="+numa")
|
||||
depends_on("llvm-openmp", when="%apple-clang +openmp")
|
||||
depends_on("valgrind", when="+valgrind")
|
||||
with when("+rocm"):
|
||||
depends_on("hsa-rocr-dev")
|
||||
depends_on("hip")
|
||||
depends_on("rccl")
|
||||
depends_on("rccl", when="+nccl")
|
||||
depends_on("rocprim")
|
||||
depends_on("hipcub")
|
||||
depends_on("rocthrust")
|
||||
@ -423,6 +424,7 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False):
|
||||
|
||||
enable_or_disable("rocm")
|
||||
if "+rocm" in self.spec:
|
||||
env.set("PYTORCH_ROCM_ARCH", ";".join(self.spec.variants["amdgpu_target"].value))
|
||||
env.set("HSA_PATH", self.spec["hsa-rocr-dev"].prefix)
|
||||
env.set("ROCBLAS_PATH", self.spec["rocblas"].prefix)
|
||||
env.set("ROCFFT_PATH", self.spec["rocfft"].prefix)
|
||||
@ -432,7 +434,8 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False):
|
||||
env.set("HIPRAND_PATH", self.spec["rocrand"].prefix)
|
||||
env.set("ROCRAND_PATH", self.spec["rocrand"].prefix)
|
||||
env.set("MIOPEN_PATH", self.spec["miopen-hip"].prefix)
|
||||
env.set("RCCL_PATH", self.spec["rccl"].prefix)
|
||||
if "+nccl" in self.spec:
|
||||
env.set("RCCL_PATH", self.spec["rccl"].prefix)
|
||||
env.set("ROCPRIM_PATH", self.spec["rocprim"].prefix)
|
||||
env.set("HIPCUB_PATH", self.spec["hipcub"].prefix)
|
||||
env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix)
|
||||
@ -454,7 +457,7 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False):
|
||||
enable_or_disable("breakpad")
|
||||
|
||||
enable_or_disable("nccl")
|
||||
if "+nccl" in self.spec:
|
||||
if "+cuda+nccl" in self.spec:
|
||||
env.set("NCCL_LIB_DIR", self.spec["nccl"].libs.directories[0])
|
||||
env.set("NCCL_INCLUDE_DIR", self.spec["nccl"].prefix.include)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user