diff --git a/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-aarch64-cuda/spack.yaml b/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-aarch64-cuda/spack.yaml
index 47f4eda0f12..d37f8ec7d27 100644
--- a/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-aarch64-cuda/spack.yaml
+++ b/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-aarch64-cuda/spack.yaml
@@ -12,6 +12,13 @@ spack:
       require: ~cuda
     mpi:
       require: openmpi
+    py-torch:
+      require:
+      - target=aarch64
+      - ~rocm
+      - +cuda
+      - cuda_arch=80
+      - ~flash_attention
 
   specs:
   # Horovod
diff --git a/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-cuda/spack.yaml b/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-cuda/spack.yaml
index 05b570f8f9a..7b508debd37 100644
--- a/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-cuda/spack.yaml
+++ b/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-cuda/spack.yaml
@@ -12,6 +12,13 @@ spack:
       require: ~cuda
     mpi:
       require: openmpi
+    py-torch:
+      require:
+      - target=x86_64_v3
+      - ~rocm
+      - +cuda
+      - cuda_arch=80
+      - ~flash_attention
 
   specs:
   # Horovod
diff --git a/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-rocm/spack.yaml b/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-rocm/spack.yaml
index 4b66256255a..f3bdf578fa1 100644
--- a/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-rocm/spack.yaml
+++ b/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-rocm/spack.yaml
@@ -11,6 +11,13 @@ spack:
       require: "osmesa"
     mpi:
       require: openmpi
+    py-torch:
+      require:
+      - target=x86_64_v3
+      - ~cuda
+      - +rocm
+      - amdgpu_target=gfx90a
+      - ~flash_attention
 
   specs:
   # Horovod
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index 57e7738d8d0..c0ad2f065e8 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -114,6 +114,12 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
         description="Enable breakpad crash dump library",
         when="@1.10:1.11",
     )
+    # Flash attention has very high memory requirements that may cause the build to fail
+    # https://github.com/pytorch/pytorch/issues/111526
+    # https://github.com/pytorch/pytorch/issues/124018
+    _desc = "Build the flash_attention kernel for scaled dot product attention"
+    variant("flash_attention", default=True, description=_desc, when="@1.13:+cuda")
+    variant("flash_attention", default=True, description=_desc, when="@1.13:+rocm")
     # py-torch has strict dependencies on old protobuf/py-protobuf versions that
     # cause problems with other packages that require newer versions of protobuf
     # and py-protobuf --> provide an option to use the internal/vendored protobuf.
@@ -594,17 +600,13 @@ def enable_or_disable(variant, keyword="USE", var=None):
             env.set("CUDNN_INCLUDE_DIR", self.spec["cudnn"].prefix.include)
             env.set("CUDNN_LIBRARY", self.spec["cudnn"].libs[0])
 
-        # Flash attention has very high memory requirements that may cause the build to fail
-        # https://github.com/pytorch/pytorch/issues/111526
-        # https://github.com/pytorch/pytorch/issues/124018
-        env.set("USE_FLASH_ATTENTION", "OFF")
-
         enable_or_disable("fbgemm")
         enable_or_disable("kineto")
         enable_or_disable("magma")
         enable_or_disable("metal")
         enable_or_disable("mps")
         enable_or_disable("breakpad")
+        enable_or_disable("flash_attention")
 
         enable_or_disable("nccl")
         if "+cuda+nccl" in self.spec:
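
Note (not part of the diff): a minimal sketch of how the same constraint could be applied outside the CI stacks, for example in a site-level packages.yaml, assuming the new flash_attention variant added above; the spec values mirror the x86_64 CUDA pipeline config and would be adjusted per site:

    packages:
      py-torch:
        require:
        # flash_attention stays on by default, so memory-constrained builders
        # opt out explicitly, as the CI stacks above now do
        - +cuda
        - cuda_arch=80
        - ~flash_attention

An equivalent one-off spec on the command line would be something like: spack install py-torch +cuda cuda_arch=80 ~flash_attention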