PyTorch: build flash attention by default, except in CI (#48521)
* PyTorch: build flash attention by default, except in CI * Variant is boolean, only available when +cuda/+rocm * desc -> _desc
This commit is contained in:
parent
18cd922aab
commit
46f5b192ef
@ -12,6 +12,13 @@ spack:
|
||||
require: ~cuda
|
||||
mpi:
|
||||
require: openmpi
|
||||
py-torch:
|
||||
require:
|
||||
- target=aarch64
|
||||
- ~rocm
|
||||
- +cuda
|
||||
- cuda_arch=80
|
||||
- ~flash_attention
|
||||
|
||||
specs:
|
||||
# Horovod
|
||||
|
@ -12,6 +12,13 @@ spack:
|
||||
require: ~cuda
|
||||
mpi:
|
||||
require: openmpi
|
||||
py-torch:
|
||||
require:
|
||||
- target=x86_64_v3
|
||||
- ~rocm
|
||||
- +cuda
|
||||
- cuda_arch=80
|
||||
- ~flash_attention
|
||||
|
||||
specs:
|
||||
# Horovod
|
||||
|
@ -11,6 +11,13 @@ spack:
|
||||
require: "osmesa"
|
||||
mpi:
|
||||
require: openmpi
|
||||
py-torch:
|
||||
require:
|
||||
- target=x86_64_v3
|
||||
- ~cuda
|
||||
- +rocm
|
||||
- amdgpu_target=gfx90a
|
||||
- ~flash_attention
|
||||
|
||||
specs:
|
||||
# Horovod
|
||||
|
@ -114,6 +114,12 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
|
||||
description="Enable breakpad crash dump library",
|
||||
when="@1.10:1.11",
|
||||
)
|
||||
# Flash attention has very high memory requirements that may cause the build to fail
|
||||
# https://github.com/pytorch/pytorch/issues/111526
|
||||
# https://github.com/pytorch/pytorch/issues/124018
|
||||
_desc = "Build the flash_attention kernel for scaled dot product attention"
|
||||
variant("flash_attention", default=True, description=_desc, when="@1.13:+cuda")
|
||||
variant("flash_attention", default=True, description=_desc, when="@1.13:+rocm")
|
||||
# py-torch has strict dependencies on old protobuf/py-protobuf versions that
|
||||
# cause problems with other packages that require newer versions of protobuf
|
||||
# and py-protobuf --> provide an option to use the internal/vendored protobuf.
|
||||
@ -594,17 +600,13 @@ def enable_or_disable(variant, keyword="USE", var=None):
|
||||
env.set("CUDNN_INCLUDE_DIR", self.spec["cudnn"].prefix.include)
|
||||
env.set("CUDNN_LIBRARY", self.spec["cudnn"].libs[0])
|
||||
|
||||
# Flash attention has very high memory requirements that may cause the build to fail
|
||||
# https://github.com/pytorch/pytorch/issues/111526
|
||||
# https://github.com/pytorch/pytorch/issues/124018
|
||||
env.set("USE_FLASH_ATTENTION", "OFF")
|
||||
|
||||
enable_or_disable("fbgemm")
|
||||
enable_or_disable("kineto")
|
||||
enable_or_disable("magma")
|
||||
enable_or_disable("metal")
|
||||
enable_or_disable("mps")
|
||||
enable_or_disable("breakpad")
|
||||
enable_or_disable("flash_attention")
|
||||
|
||||
enable_or_disable("nccl")
|
||||
if "+cuda+nccl" in self.spec:
|
||||
|
Loading…
Reference in New Issue
Block a user