mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-20 10:27:41 +08:00)

feat: Added "tanh" option to GELU approximation (#1268)

This commit is contained in:
parent 8cfb9fc0b8
commit 6e06e3a904
@@ -536,12 +536,17 @@ class GELU(Module):

     .. math::
         \textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
-        \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
+        \textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)

     respectively.

+    .. note::
+        For compatibility with the PyTorch API, 'tanh' can be used as an alias
+        for 'precise'.
+
     See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
     functional equivalents and information regarding error bounds.

     Args:
         approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
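For reference, here is a minimal sketch of what the two approximations in the docstring compute, compared against the exact erf-based GELU. It assumes only mlx.core; the helper names gelu_exact, gelu_precise, and gelu_fast are illustrative, not the library's own functions:

import math
import mlx.core as mx

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF expressed via erf
    return x * (1 + mx.erf(x / math.sqrt(2))) / 2

def gelu_precise(x):
    # GELUApprox from the docstring: tanh-based approximation
    return 0.5 * x * (1 + mx.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

def gelu_fast(x):
    # GELUFast from the docstring: sigmoid-based approximation with the 1.702 constant
    return x * mx.sigmoid(1.702 * x)

x = mx.arange(-6.0, 6.0, 12 / 100)
print(mx.max(mx.abs(gelu_precise(x) - gelu_exact(x))))  # tight: on the order of 1e-4
print(mx.max(mx.abs(gelu_fast(x) - gelu_exact(x))))     # looser, but still close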
@@ -552,13 +557,13 @@ class GELU(Module):

         if approx == "none":
             self._act = gelu
-        elif approx == "precise":
+        elif approx == "precise" or approx == "tanh":
             self._act = gelu_approx
         elif approx == "fast":
             self._act = gelu_fast_approx
         else:
             raise ValueError(
-                f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
+                f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
             )

     def __call__(self, x):
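In practice, the new alias behaves exactly like 'precise'. A short usage sketch, assuming an environment with mlx installed (the sample input values are arbitrary):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-1.0, 0.0, 1.0, 2.0])

# 'precise' and 'tanh' now dispatch to the same gelu_approx activation
precise = nn.GELU(approx="precise")(x)
tanh_alias = nn.GELU(approx="tanh")(x)
print(mx.allclose(precise, tanh_alias))  # array(True, dtype=bool)

# Anything outside ['none', 'precise', 'tanh', 'fast'] still raises a ValueError
try:
    nn.GELU(approx="exact")
except ValueError as e:
    print(e)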
@@ -722,8 +722,13 @@ class TestLayers(mlx_tests.MLXTestCase):

         out = nn.GELU()(mx.array(inputs))
         self.assertTrue(np.allclose(out, expected))

+        # Test the precise/tanh approximation
         out_approx = nn.GELU(approx="precise")(mx.array(inputs))
+        out_approx_tanh = nn.GELU(approx="tanh")(mx.array(inputs))
         self.assertTrue(np.allclose(out_approx, expected_approx))
+        self.assertTrue(np.allclose(out_approx_tanh, expected_approx))
+        self.assertTrue(np.allclose(out_approx, out_approx_tanh))

         # Crudely check the approximations
         x = mx.arange(-6.0, 6.0, 12 / 100)
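In the same spirit as the test above, the accepted approximation options could also be exercised parametrically. This is a hedged sketch rather than part of the committed test file; the test-class name and the 0.03 tolerance are illustrative assumptions:

import unittest
import mlx.core as mx
import mlx.nn as nn
import numpy as np

class TestGELUApproxOptions(unittest.TestCase):
    def test_accepted_options_are_close_to_exact(self):
        x = mx.arange(-6.0, 6.0, 12 / 100)
        exact = nn.GELU()(x)
        # 'precise' and 'tanh' select the same approximation; 'fast' is looser,
        # so a deliberately loose tolerance is used for all of them.
        for approx in ("precise", "tanh", "fast"):
            with self.subTest(approx=approx):
                out = nn.GELU(approx=approx)(x)
                self.assertTrue(np.allclose(out, exact, atol=0.03))

if __name__ == "__main__":
    unittest.main()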