feat: Added "tanh" option to GELU approximation (#1268)

Atakan Tekparmak, 2024-07-28 09:07:56 +02:00 (committed by GitHub)
parent 8cfb9fc0b8
commit 6e06e3a904
2 changed files with 13 additions and 3 deletions


@@ -536,12 +536,17 @@ class GELU(Module):
     .. math::
         \textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
-        \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
+        \textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)

     respectively.

+    .. note::
+        For compatibility with the PyTorch API, 'tanh' can be used as an alias
+        for 'precise'.
+
     See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
     functional equivalents and information regarding error bounds.

     Args:
         approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use, if any.
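
The two formulas in the docstring are straightforward to transcribe. A minimal sketch in plain mlx, assuming only `mlx.core` and the standard library (the `_ref` names are illustrative, not part of this commit):

import math

import mlx.core as mx

def gelu_approx_ref(x):
    # The "precise" approximation (aliased as "tanh" by this commit)
    return 0.5 * x * (1 + mx.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))

def gelu_fast_ref(x):
    # The "fast" sigmoid-based approximation, with the corrected 1.702 constant
    return x * mx.sigmoid(1.702 * x)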
@@ -552,13 +557,13 @@ class GELU(Module):
         if approx == "none":
             self._act = gelu
-        elif approx == "precise":
+        elif approx == "precise" or approx == "tanh":
             self._act = gelu_approx
         elif approx == "fast":
             self._act = gelu_fast_approx
         else:
             raise ValueError(
-                f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
+                f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
             )

     def __call__(self, x):
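
Since both spellings dispatch to the same underlying function, their outputs match exactly, and anything outside the accepted options raises the updated error. A quick usage sketch (variable names here are hypothetical):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-1.0, 0.0, 1.0])
precise = nn.GELU(approx="precise")(x)
tanh_alias = nn.GELU(approx="tanh")(x)  # alias added by this commit
assert mx.allclose(precise, tanh_alias).item()

try:
    nn.GELU(approx="sigmoid")  # not a valid option
except ValueError as e:
    print(e)  # lists ['none', 'precise', 'tanh', 'fast']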


@@ -722,8 +722,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         out = nn.GELU()(mx.array(inputs))
         self.assertTrue(np.allclose(out, expected))

         # Test the precise/tanh approximation
         out_approx = nn.GELU(approx="precise")(mx.array(inputs))
+        out_approx_tanh = nn.GELU(approx="tanh")(mx.array(inputs))
         self.assertTrue(np.allclose(out_approx, expected_approx))
+        self.assertTrue(np.allclose(out_approx_tanh, expected_approx))
+        self.assertTrue(np.allclose(out_approx, out_approx_tanh))

         # Crudely check the approximations
         x = mx.arange(-6.0, 6.0, 12 / 100)
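
The check is truncated here; it samples a grid and bounds how far each approximation strays from exact GELU. A self-contained sketch in the same spirit (the printed deviations, not any specific tolerance, are assumptions of this sketch):

import mlx.core as mx
import mlx.nn as nn

x = mx.arange(-6.0, 6.0, 12 / 100)
exact = nn.GELU()(x)
precise = nn.GELU(approx="precise")(x)
fast = nn.GELU(approx="fast")(x)

# Maximum absolute deviation of each approximation from exact GELU
print(mx.abs(exact - precise).max().item())  # small across the grid
print(mx.abs(exact - fast).max().item())     # larger, but still a close fit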