gelu tanh approx (#989)
* gelu tanh approx
* gelu tanh approx
* replace gelu approx with tanh approach
* fix comments
* fix comment
parent cd9e184529
commit 107ba2891a
@@ -163,15 +163,14 @@ def gelu_approx(x):
     See :func:`gelu` for the exact computation.
 
     This function approximates ``gelu`` with a maximum absolute error :math:`<
-    0.0003` in the range :math:`[-6, 6]` using the following
+    0.0005` in the range :math:`[-6, 6]` using the following
 
     .. math::
 
-        x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)
+        x = 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right)
 
-    where :math:`\sigma(\cdot)` is the logistic sigmoid.
     """
-    return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
+    return 0.5 * x * (1 + mx.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
 
 
 @partial(mx.compile, shapeless=True)
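For reference, the new formula is the standard tanh-based GELU approximation from Hendrycks & Gimpel. Below is a minimal standalone sketch (plain Python, no MLX needed; the helper names are ours, not part of the commit) that samples :math:`[-6, 6]` and checks the documented error bound against the exact erf-based GELU:

```python
import math

def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh_approx(x):
    # The tanh approximation this commit switches to.
    return 0.5 * x * (
        1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))
    )

# Sample [-6, 6] and report the worst-case absolute error; per the
# updated docstring and test it should come in under 0.0005.
xs = [-6.0 + 12.0 * i / 1000 for i in range(1001)]
print(max(abs(gelu_exact(x) - gelu_tanh_approx(x)) for x in xs))
```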
@@ -504,7 +503,7 @@ class GELU(Module):
     However, if ``approx`` is set to 'precise' or 'fast' it applies
 
     .. math::
-        \textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\
+        \textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
         \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
 
     respectively.
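To make the variants concrete, here is a short usage sketch assuming the ``nn.GELU`` API shown in this diff: ``approx="precise"`` now maps to the tanh approximation and ``approx="fast"`` to the sigmoid shortcut. The printed bounds mirror the tolerances in the updated test:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.arange(-6.0, 6.0, 0.25)

y_exact = nn.GELU()(x)                    # exact erf-based GELU
y_precise = nn.GELU(approx="precise")(x)  # tanh approximation (this commit)
y_fast = nn.GELU(approx="fast")(x)        # sigmoid shortcut, x * sigma(1.773 x)

print(mx.abs(y_exact - y_precise).max())  # expected < 0.0005
print(mx.abs(y_exact - y_fast).max())     # expected < 0.025
```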
@@ -717,16 +717,22 @@ class TestLayers(mlx_tests.MLXTestCase):
         expected = np.array(
             [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
         )
+        # From: jax.nn.gelu(np.array(inputs), approximate=True)
+        expected_approx = np.array(
+            [1.0091482, -0.1693441, 0.22918446, 0.60491, 0.4945476]
+        )
 
         out = nn.GELU()(mx.array(inputs))
         self.assertTrue(np.allclose(out, expected))
+        out_approx = nn.GELU(approx="precise")(mx.array(inputs))
+        self.assertTrue(np.allclose(out_approx, expected_approx))
 
         # Crudely check the approximations
         x = mx.arange(-6.0, 6.0, 12 / 100)
         y = nn.gelu(x)
         y_hat1 = nn.gelu_approx(x)
         y_hat2 = nn.gelu_fast_approx(x)
-        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
+        self.assertLess(mx.abs(y - y_hat1).max(), 0.0005)
         self.assertLess(mx.abs(y - y_hat2).max(), 0.025)
 
     def test_sin_pe(self):
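The new comment attributes the reference values to ``jax.nn.gelu``. A hedged reproduction sketch follows; the real ``inputs`` array is defined earlier in the test file and is not part of this diff, so the array below is only a placeholder:

```python
import numpy as np
import jax.nn

# `inputs` is defined earlier in the test file and not shown in this
# diff; these values are placeholders for illustration only.
inputs = np.array([1.15, -0.37, 0.37, 0.77, 0.62])

# approximate=True selects the tanh-based GELU, matching nn.gelu_approx.
print(jax.nn.gelu(inputs, approximate=True))
```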