Add transformer engine package (#43982)

* Add py-flash-attn@2.4.2
* Add py-transfomer-engine package

---------

Co-authored-by: Tamara Dahlgren <35777542+tldahlgren@users.noreply.github.com>
This commit is contained in:
Auriane R 2024-05-07 23:56:34 +02:00 committed by GitHub
parent f6d50f790e
commit 84ed4cd331
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -0,0 +1,48 @@
# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from spack.package import *
class PyTransformerEngine(PythonPackage):
"""
A library for accelerating Transformer models on NVIDIA GPUs, including fp8 precision on Hopper
GPUs.
"""
homepage = "https://github.com/NVIDIA/TransformerEngine"
url = "https://github.com/NVIDIA/TransformerEngine/archive/refs/tags/v0.0.tar.gz"
git = "https://github.com/NVIDIA/TransformerEngine.git"
maintainers("aurianer")
license("Apache-2.0")
version("1.4", tag="v1.4", submodules=True)
version("main", branch="main", submodules=True)
variant("userbuffers", default=True, description="Enable userbuffers, this option needs MPI.")
depends_on("py-setuptools", type="build")
depends_on("cmake@3.18:")
depends_on("py-pydantic")
depends_on("py-importlib-metadata")
with default_args(type=("build", "run")):
depends_on("py-accelerate")
depends_on("py-datasets")
depends_on("py-flash-attn@2.2:2.4.2")
depends_on("py-packaging")
depends_on("py-torchvision")
depends_on("py-transformers")
depends_on("mpi", when="+userbuffers")
with default_args(type=("build", "link", "run")):
depends_on("py-torch+cuda+cudnn")
def setup_build_environment(self, env):
env.set("NVTE_FRAMEWORK", "pytorch")
if self.spec.satisfies("+userbuffers"):
env.set("NVTE_WITH_USERBUFFERS", "1")
env.set("MPI_HOME", self.spec["mpi"].prefix)