py-torch: fix +rocm+nccl build (#32771)
This commit is contained in:
		| @@ -10,7 +10,7 @@ | |||||||
| from spack.package import * | from spack.package import * | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class PyTorch(PythonPackage, CudaPackage): | class PyTorch(PythonPackage, CudaPackage, ROCmPackage): | ||||||
|     """Tensors and Dynamic neural networks in Python |     """Tensors and Dynamic neural networks in Python | ||||||
|     with strong GPU acceleration.""" |     with strong GPU acceleration.""" | ||||||
| 
 | 
 | ||||||
| @@ -100,6 +100,7 @@ class PyTorch(PythonPackage, CudaPackage): | |||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     conflicts("+cuda+rocm") |     conflicts("+cuda+rocm") | ||||||
|  |     conflicts("+tensorpipe", when="+rocm", msg="TensorPipe doesn't yet support ROCm") | ||||||
|     conflicts("+breakpad", when="target=ppc64:") |     conflicts("+breakpad", when="target=ppc64:") | ||||||
|     conflicts("+breakpad", when="target=ppc64le:") |     conflicts("+breakpad", when="target=ppc64le:") | ||||||
| 
 | 
 | ||||||
| @@ -177,14 +178,14 @@ class PyTorch(PythonPackage, CudaPackage): | |||||||
|     depends_on("cudnn@7:", when="@1.6:+cudnn") |     depends_on("cudnn@7:", when="@1.6:+cudnn") | ||||||
|     depends_on("magma+cuda", when="+magma+cuda") |     depends_on("magma+cuda", when="+magma+cuda") | ||||||
|     depends_on("magma+rocm", when="+magma+rocm") |     depends_on("magma+rocm", when="+magma+rocm") | ||||||
|     depends_on("nccl", when="+nccl") |     depends_on("nccl", when="+nccl+cuda") | ||||||
|     depends_on("numactl", when="+numa") |     depends_on("numactl", when="+numa") | ||||||
|     depends_on("llvm-openmp", when="%apple-clang +openmp") |     depends_on("llvm-openmp", when="%apple-clang +openmp") | ||||||
|     depends_on("valgrind", when="+valgrind") |     depends_on("valgrind", when="+valgrind") | ||||||
|     with when("+rocm"): |     with when("+rocm"): | ||||||
|         depends_on("hsa-rocr-dev") |         depends_on("hsa-rocr-dev") | ||||||
|         depends_on("hip") |         depends_on("hip") | ||||||
|         depends_on("rccl") |         depends_on("rccl", when="+nccl") | ||||||
|         depends_on("rocprim") |         depends_on("rocprim") | ||||||
|         depends_on("hipcub") |         depends_on("hipcub") | ||||||
|         depends_on("rocthrust") |         depends_on("rocthrust") | ||||||
| @@ -423,6 +424,7 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False): | |||||||
| 
 | 
 | ||||||
|         enable_or_disable("rocm") |         enable_or_disable("rocm") | ||||||
|         if "+rocm" in self.spec: |         if "+rocm" in self.spec: | ||||||
|  |             env.set("PYTORCH_ROCM_ARCH", ";".join(self.spec.variants["amdgpu_target"].value)) | ||||||
|             env.set("HSA_PATH", self.spec["hsa-rocr-dev"].prefix) |             env.set("HSA_PATH", self.spec["hsa-rocr-dev"].prefix) | ||||||
|             env.set("ROCBLAS_PATH", self.spec["rocblas"].prefix) |             env.set("ROCBLAS_PATH", self.spec["rocblas"].prefix) | ||||||
|             env.set("ROCFFT_PATH", self.spec["rocfft"].prefix) |             env.set("ROCFFT_PATH", self.spec["rocfft"].prefix) | ||||||
| @@ -432,6 +434,7 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False): | |||||||
|             env.set("HIPRAND_PATH", self.spec["rocrand"].prefix) |             env.set("HIPRAND_PATH", self.spec["rocrand"].prefix) | ||||||
|             env.set("ROCRAND_PATH", self.spec["rocrand"].prefix) |             env.set("ROCRAND_PATH", self.spec["rocrand"].prefix) | ||||||
|             env.set("MIOPEN_PATH", self.spec["miopen-hip"].prefix) |             env.set("MIOPEN_PATH", self.spec["miopen-hip"].prefix) | ||||||
|  |             if "+nccl" in self.spec: | ||||||
|                 env.set("RCCL_PATH", self.spec["rccl"].prefix) |                 env.set("RCCL_PATH", self.spec["rccl"].prefix) | ||||||
|             env.set("ROCPRIM_PATH", self.spec["rocprim"].prefix) |             env.set("ROCPRIM_PATH", self.spec["rocprim"].prefix) | ||||||
|             env.set("HIPCUB_PATH", self.spec["hipcub"].prefix) |             env.set("HIPCUB_PATH", self.spec["hipcub"].prefix) | ||||||
| @@ -454,7 +457,7 @@ def enable_or_disable(variant, keyword="USE", var=None, newer=False): | |||||||
|         enable_or_disable("breakpad") |         enable_or_disable("breakpad") | ||||||
| 
 | 
 | ||||||
|         enable_or_disable("nccl") |         enable_or_disable("nccl") | ||||||
|         if "+nccl" in self.spec: |         if "+cuda+nccl" in self.spec: | ||||||
|             env.set("NCCL_LIB_DIR", self.spec["nccl"].libs.directories[0]) |             env.set("NCCL_LIB_DIR", self.spec["nccl"].libs.directories[0]) | ||||||
|             env.set("NCCL_INCLUDE_DIR", self.spec["nccl"].prefix.include) |             env.set("NCCL_INCLUDE_DIR", self.spec["nccl"].prefix.include) | ||||||
| 
 | 
 | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Adam J. Stewart
					Adam J. Stewart