PyTorch: build flash attention by default, except in CI (#48521)

* PyTorch: build flash attention by default, except in CI

* Variant is boolean and only available with +cuda/+rocm

* desc -> _desc
Adam J. Stewart 2025-02-11 22:20:10 +01:00 committed by GitHub
parent 18cd922aab
commit 46f5b192ef
4 changed files with 28 additions and 5 deletions

@@ -12,6 +12,13 @@ spack:
      require: ~cuda
    mpi:
      require: openmpi
    py-torch:
      require:
        - target=aarch64
        - ~rocm
        - +cuda
        - cuda_arch=80
        - ~flash_attention
  specs:
    # Horovod

@@ -12,6 +12,13 @@ spack:
      require: ~cuda
    mpi:
      require: openmpi
    py-torch:
      require:
        - target=x86_64_v3
        - ~rocm
        - +cuda
        - cuda_arch=80
        - ~flash_attention
  specs:
    # Horovod

@@ -11,6 +11,13 @@ spack:
require: "osmesa"
mpi:
require: openmpi
py-torch:
require:
- target=x86_64_v3
- ~cuda
- +rocm
- amdgpu_target=gfx90a
- ~flash_attention
specs:
# Horovod

@@ -114,6 +114,12 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
description="Enable breakpad crash dump library",
when="@1.10:1.11",
)
# Flash attention has very high memory requirements that may cause the build to fail
# https://github.com/pytorch/pytorch/issues/111526
# https://github.com/pytorch/pytorch/issues/124018
_desc = "Build the flash_attention kernel for scaled dot product attention"
variant("flash_attention", default=True, description=_desc, when="@1.13:+cuda")
variant("flash_attention", default=True, description=_desc, when="@1.13:+rocm")
# py-torch has strict dependencies on old protobuf/py-protobuf versions that
# cause problems with other packages that require newer versions of protobuf
# and py-protobuf --> provide an option to use the internal/vendored protobuf.
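
For context: declaring the same variant twice with different when= conditions is how a Spack package expresses "available with +cuda or +rocm". A single when= clause describes one condition and cannot state an either/or, so the variant is declared once per condition and exists on a spec whenever any of them holds. Below is a minimal, self-contained sketch of the same pattern; the Demo package, its metadata, and its cuda/rocm variants are hypothetical and for illustration only, not part of this change.

from spack.package import *


class Demo(Package):
    """Hypothetical package illustrating a conditional boolean variant."""

    homepage = "https://example.com/demo"        # placeholder metadata
    url = "https://example.com/demo-1.0.tar.gz"  # placeholder metadata

    version("1.0", sha256="0000000000000000000000000000000000000000000000000000000000000000")

    variant("cuda", default=False, description="Build with CUDA support")
    variant("rocm", default=False, description="Build with ROCm support")

    # flash_attention only exists when a GPU backend is enabled; on specs with
    # neither +cuda nor +rocm, +/~flash_attention cannot be requested at all.
    _desc = "Build the flash_attention kernel"
    variant("flash_attention", default=True, description=_desc, when="+cuda")
    variant("flash_attention", default=True, description=_desc, when="+rocm")

A spec like demo+cuda then gets +flash_attention by default and can opt out with ~flash_attention, which is exactly what the CI stacks above do for py-torch.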
@@ -594,17 +600,13 @@ def enable_or_disable(variant, keyword="USE", var=None):
env.set("CUDNN_INCLUDE_DIR", self.spec["cudnn"].prefix.include)
env.set("CUDNN_LIBRARY", self.spec["cudnn"].libs[0])
# Flash attention has very high memory requirements that may cause the build to fail
# https://github.com/pytorch/pytorch/issues/111526
# https://github.com/pytorch/pytorch/issues/124018
env.set("USE_FLASH_ATTENTION", "OFF")
enable_or_disable("fbgemm")
enable_or_disable("kineto")
enable_or_disable("magma")
enable_or_disable("metal")
enable_or_disable("mps")
enable_or_disable("breakpad")
enable_or_disable("flash_attention")
enable_or_disable("nccl")
if "+cuda+nccl" in self.spec: