mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 01:46:37 +08:00
feat: Added "tanh" option to GELU approximation (#1268)
This commit is contained in:
parent
8cfb9fc0b8
commit
6e06e3a904
@ -536,12 +536,17 @@ class GELU(Module):
|
||||
|
||||
.. math::
|
||||
\textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left((\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
|
||||
\textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
|
||||
\textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)
|
||||
|
||||
respectively.
|
||||
|
||||
.. note::
|
||||
For compatibility with the PyTorch API, 'tanh' can be used as an alias
|
||||
for 'precise'.
|
||||
|
||||
See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
|
||||
functional equivalents and information regarding error bounds.
|
||||
|
||||
|
||||
Args:
|
||||
approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
|
||||
@ -552,13 +557,13 @@ class GELU(Module):
|
||||
|
||||
if approx == "none":
|
||||
self._act = gelu
|
||||
elif approx == "precise":
|
||||
elif approx == "precise" or approx == "tanh":
|
||||
self._act = gelu_approx
|
||||
elif approx == "fast":
|
||||
self._act = gelu_fast_approx
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
|
||||
f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
|
@ -722,8 +722,13 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
|
||||
out = nn.GELU()(mx.array(inputs))
|
||||
self.assertTrue(np.allclose(out, expected))
|
||||
|
||||
# Test the precise/tanh approximation
|
||||
out_approx = nn.GELU(approx="precise")(mx.array(inputs))
|
||||
out_approx_tanh = nn.GELU(approx="tanh")(mx.array(inputs))
|
||||
self.assertTrue(np.allclose(out_approx, expected_approx))
|
||||
self.assertTrue(np.allclose(out_approx_tanh, expected_approx))
|
||||
self.assertTrue(np.allclose(out_approx, out_approx_tanh))
|
||||
|
||||
# Crudely check the approximations
|
||||
x = mx.arange(-6.0, 6.0, 12 / 100)
|
||||
|
Loading…
Reference in New Issue
Block a user