gelu tanh approx (#989)

* gelu tanh approx

* gelu tanh approx

* replace gelu approx with tanh approach

* fix comments

* fix comment
This commit is contained in:
Shiyu
2024-04-16 10:49:00 +08:00
committed by GitHub
parent cd9e184529
commit 107ba2891a
2 changed files with 11 additions and 6 deletions

View File

@@ -163,15 +163,14 @@ def gelu_approx(x):
See :func:`gelu` for the exact computation.
This function approximates ``gelu`` with a maximum absolute error :math:`<
0.0003` in the range :math:`[-6, 6]` using the following
0.0005` in the range :math:`[-6, 6]` using the following
.. math::
x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)
x = 0.5 * x * \left(1 + \text{Tanh}\left((\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right)
where :math:`\sigma(\cdot)` is the logistic sigmoid.
"""
return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
return 0.5 * x * (1 + mx.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
@partial(mx.compile, shapeless=True)
@@ -504,7 +503,7 @@ class GELU(Module):
However, if ``approx`` is set to 'precise' or 'fast' it applies
.. math::
\textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\
\textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left((\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
\textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
respectively.