py-torch-nvidia-apex: fix +cuda build (#33070)
This commit is contained in:
		| @@ -0,0 +1,13 @@ | |||||||
|  | diff --git a/setup.py b/setup.py | ||||||
|  | index 063b42d..7388297 100644 | ||||||
|  | --- a/setup.py | ||||||
|  | +++ b/setup.py | ||||||
|  | @@ -31,7 +31,7 @@ if not torch.cuda.is_available(): | ||||||
|  |            'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n' | ||||||
|  |            'If you wish to cross-compile for a single specific architecture,\n' | ||||||
|  |            'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') | ||||||
|  | -    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: | ||||||
|  | +    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and cpp_extension.CUDA_HOME is not None: | ||||||
|  |          _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) | ||||||
|  |          if int(bare_metal_major) == 11: | ||||||
|  |              os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" | ||||||
| @@ -26,7 +26,8 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage): | |||||||
|     variant("cuda", default=True, description="Build with CUDA") |     variant("cuda", default=True, description="Build with CUDA") | ||||||
| 
 | 
 | ||||||
|     # https://github.com/NVIDIA/apex/issues/1498 |     # https://github.com/NVIDIA/apex/issues/1498 | ||||||
|     conflicts("~cuda") |     # https://github.com/NVIDIA/apex/pull/1499 | ||||||
|  |     patch("1499.patch", when="@2020-10-19") | ||||||
| 
 | 
 | ||||||
|     def setup_build_environment(self, env): |     def setup_build_environment(self, env): | ||||||
|         if "+cuda" in self.spec: |         if "+cuda" in self.spec: | ||||||
| @@ -37,6 +38,8 @@ def setup_build_environment(self, env): | |||||||
|                     for i in self.spec.variants["cuda_arch"].value |                     for i in self.spec.variants["cuda_arch"].value | ||||||
|                 ) |                 ) | ||||||
|                 env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch) |                 env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch) | ||||||
|  |         else: | ||||||
|  |             env.unset("CUDA_HOME") | ||||||
| 
 | 
 | ||||||
|     def global_options(self, spec, prefix): |     def global_options(self, spec, prefix): | ||||||
|         args = [] |         args = [] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Adam J. Stewart
					Adam J. Stewart