feat: Added "tanh" option to GELU approximation (#1268)

This commit is contained in:
Atakan Tekparmak
2024-07-28 09:07:56 +02:00
committed by GitHub
parent 8cfb9fc0b8
commit 6e06e3a904
2 changed files with 13 additions and 3 deletions

View File

@@ -536,12 +536,17 @@ class GELU(Module):
.. math::
\textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
\textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
\textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)
respectively.
.. note::
For compatibility with the PyTorch API, 'tanh' can be used as an alias
for 'precise'.
See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
functional equivalents and information regarding error bounds.
Args:
approx ('none' | 'precise' | 'tanh' | 'fast'): Which approximation of GELU to use, if any. 'tanh' is an alias for 'precise'.
@@ -552,13 +557,13 @@ class GELU(Module):
if approx == "none":
self._act = gelu
elif approx == "precise":
elif approx == "precise" or approx == "tanh":
self._act = gelu_approx
elif approx == "fast":
self._act = gelu_fast_approx
else:
raise ValueError(
f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
)
def __call__(self, x):