py-torch-nvidia-apex: fix +cuda build (#33070)
This commit is contained in:
parent
cbc867a24c
commit
8d9a035d12
@ -0,0 +1,13 @@
|
||||
diff --git a/setup.py b/setup.py
|
||||
index 063b42d..7388297 100644
|
||||
--- a/setup.py
|
||||
+++ b/setup.py
|
||||
@@ -31,7 +31,7 @@ if not torch.cuda.is_available():
|
||||
'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
|
||||
'If you wish to cross-compile for a single specific architecture,\n'
|
||||
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
|
||||
- if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
|
||||
+ if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and cpp_extension.CUDA_HOME is not None:
|
||||
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
|
||||
if int(bare_metal_major) == 11:
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
|
@ -26,7 +26,8 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage):
|
||||
variant("cuda", default=True, description="Build with CUDA")
|
||||
|
||||
# https://github.com/NVIDIA/apex/issues/1498
|
||||
conflicts("~cuda")
|
||||
# https://github.com/NVIDIA/apex/pull/1499
|
||||
patch("1499.patch", when="@2020-10-19")
|
||||
|
||||
def setup_build_environment(self, env):
|
||||
if "+cuda" in self.spec:
|
||||
@ -37,6 +38,8 @@ def setup_build_environment(self, env):
|
||||
for i in self.spec.variants["cuda_arch"].value
|
||||
)
|
||||
env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch)
|
||||
else:
|
||||
env.unset("CUDA_HOME")
|
||||
|
||||
def global_options(self, spec, prefix):
|
||||
args = []
|
||||
|
Loading…
Reference in New Issue
Block a user