mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
feat: Added "tanh" option to GELU approximation (#1268)
This commit is contained in:
@@ -536,12 +536,17 @@ class GELU(Module):
|
||||
|
||||
.. math::
|
||||
\textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left((\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
|
||||
\textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
|
||||
\textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)
|
||||
|
||||
respectively.
|
||||
|
||||
.. note::
|
||||
For compatibility with the PyTorch API, 'tanh' can be used as an alias
|
||||
for 'precise'.
|
||||
|
||||
See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
|
||||
functional equivalents and information regarding error bounds.
|
||||
|
||||
|
||||
Args:
|
||||
approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
|
||||
@@ -552,13 +557,13 @@ class GELU(Module):
|
||||
|
||||
if approx == "none":
|
||||
self._act = gelu
|
||||
elif approx == "precise":
|
||||
elif approx == "precise" or approx == "tanh":
|
||||
self._act = gelu_approx
|
||||
elif approx == "fast":
|
||||
self._act = gelu_fast_approx
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
|
||||
f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
|
Reference in New Issue
Block a user