PyTorch: build flash attention by default, except in CI (#48521)
* PyTorch: build flash attention by default, except in CI
* Variant is boolean, only available when +cuda/+rocm
* desc -> _desc
parent 18cd922aab
commit 46f5b192ef
@@ -12,6 +12,13 @@ spack:
       require: ~cuda
     mpi:
       require: openmpi
+    py-torch:
+      require:
+        - target=aarch64
+        - ~rocm
+        - +cuda
+        - cuda_arch=80
+        - ~flash_attention
 
   specs:
   # Horovod
@@ -12,6 +12,13 @@ spack:
       require: ~cuda
     mpi:
       require: openmpi
+    py-torch:
+      require:
+        - target=x86_64_v3
+        - ~rocm
+        - +cuda
+        - cuda_arch=80
+        - ~flash_attention
 
   specs:
   # Horovod
@@ -11,6 +11,13 @@ spack:
       require: "osmesa"
     mpi:
       require: openmpi
+    py-torch:
+      require:
+        - target=x86_64_v3
+        - ~cuda
+        - +rocm
+        - amdgpu_target=gfx90a
+        - ~flash_attention
 
   specs:
   # Horovod
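The three stanzas above all pin ~flash_attention for the GPU CI stacks, which is the "except in CI" part of the change. A minimal sketch (not part of the commit) of what those require: lists amount to, using Spack's Spec/satisfies API; the exact spec string is illustrative and the abstract-spec semantics are assumed:

from spack.spec import Spec

# Illustrative only: one of the CI constraints above, written as a single spec.
ci_pytorch = Spec("py-torch+cuda~rocm~flash_attention cuda_arch=80 target=x86_64_v3")

# Flash attention is explicitly disabled in CI to avoid the memory blow-ups
# tracked in the PyTorch issues linked in the package.py hunk below.
assert ci_pytorch.satisfies("~flash_attention")
assert not ci_pytorch.satisfies("+flash_attention")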
@@ -114,6 +114,12 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
         description="Enable breakpad crash dump library",
         when="@1.10:1.11",
     )
+    # Flash attention has very high memory requirements that may cause the build to fail
+    # https://github.com/pytorch/pytorch/issues/111526
+    # https://github.com/pytorch/pytorch/issues/124018
+    _desc = "Build the flash_attention kernel for scaled dot product attention"
+    variant("flash_attention", default=True, description=_desc, when="@1.13:+cuda")
+    variant("flash_attention", default=True, description=_desc, when="@1.13:+rocm")
     # py-torch has strict dependencies on old protobuf/py-protobuf versions that
     # cause problems with other packages that require newer versions of protobuf
     # and py-protobuf --> provide an option to use the internal/vendored protobuf.
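Because both variant(...) declarations carry a when= clause, the flash_attention option only exists for CUDA or ROCm builds of PyTorch 1.13 and newer, and it defaults to on there. A rough sketch of how that reads from a user spec, assuming Spack's usual conditional-variant semantics (not part of this diff):

from spack.spec import Spec

gpu_build = Spec("py-torch@2.2+cuda+flash_attention cuda_arch=80")
cpu_build = Spec("py-torch@2.2~cuda~rocm")

assert gpu_build.satisfies("+flash_attention")      # GPU build, variant requested
assert not cpu_build.satisfies("+flash_attention")  # no CUDA/ROCm, so never enabled

Since the default is True, a plain py-torch+cuda build now compiles the kernels unless a stack opts out with ~flash_attention, which is exactly what the CI stanzas above do.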
@@ -594,17 +600,13 @@ def enable_or_disable(variant, keyword="USE", var=None):
             env.set("CUDNN_INCLUDE_DIR", self.spec["cudnn"].prefix.include)
             env.set("CUDNN_LIBRARY", self.spec["cudnn"].libs[0])
 
-        # Flash attention has very high memory requirements that may cause the build to fail
-        # https://github.com/pytorch/pytorch/issues/111526
-        # https://github.com/pytorch/pytorch/issues/124018
-        env.set("USE_FLASH_ATTENTION", "OFF")
-
         enable_or_disable("fbgemm")
         enable_or_disable("kineto")
         enable_or_disable("magma")
         enable_or_disable("metal")
         enable_or_disable("mps")
         enable_or_disable("breakpad")
+        enable_or_disable("flash_attention")
 
         enable_or_disable("nccl")
         if "+cuda+nccl" in self.spec:
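For context, enable_or_disable (whose signature appears in the hunk header) maps a variant onto a USE_* environment variable for PyTorch's build; the commit swaps the hard-coded USE_FLASH_ATTENTION=OFF for that mapping. A self-contained sketch of the helper's behavior, with the spec and env passed in explicitly (and the spec given as a plain string) rather than captured from the enclosing method as in the real package:

class FakeEnv:
    """Stand-in for Spack's build-environment object; only .set() is needed here."""

    def __init__(self):
        self.vars = {}

    def set(self, name, value):
        self.vars[name] = value


def enable_or_disable(variant, spec, env, keyword="USE", var=None):
    """Set <keyword>_<VAR> to ON/OFF depending on whether the variant is active."""
    if var is None:
        var = variant.upper()
    if "+" + variant in spec:
        env.set(keyword + "_" + var, "ON")
    elif "~" + variant in spec:
        env.set(keyword + "_" + var, "OFF")


env = FakeEnv()
enable_or_disable("flash_attention", "py-torch+cuda+flash_attention", env)
assert env.vars["USE_FLASH_ATTENTION"] == "ON"

env = FakeEnv()
enable_or_disable("flash_attention", "py-torch+cuda~flash_attention", env)
assert env.vars["USE_FLASH_ATTENTION"] == "OFF"

With the conditional variant, a spec that carries neither +flash_attention nor ~flash_attention (a CPU-only build) sets nothing at all, presumably leaving the decision to PyTorch's own build defaults.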