diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 962022e19..ba565e02d 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -163,15 +163,14 @@ def gelu_approx(x):
     See :func:`gelu` for the exact computation.
 
     This function approximates ``gelu`` with a maximum absolute error :math:`<
-    0.0003` in the range :math:`[-6, 6]` using the following
+    0.0005` in the range :math:`[-6, 6]` using the following
 
     .. math::
 
-        x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)
+        x = 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right)
 
-    where :math:`\sigma(\cdot)` is the logistic sigmoid.
     """
-    return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
+    return 0.5 * x * (1 + mx.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
 
 
 @partial(mx.compile, shapeless=True)
@@ -504,7 +503,7 @@ class GELU(Module):
     However, if ``approx`` is set to 'precise' or 'fast' it applies
 
     .. math::
-        \textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\
+        \textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
         \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
 
     respectively.
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 4ca5e465b..06c974652 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -717,16 +717,22 @@ class TestLayers(mlx_tests.MLXTestCase):
         expected = np.array(
             [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
         )
+        # From: jax.nn.gelu(np.array(inputs), approximate=True)
+        expected_approx = np.array(
+            [1.0091482, -0.1693441, 0.22918446, 0.60491, 0.4945476]
+        )
 
         out = nn.GELU()(mx.array(inputs))
         self.assertTrue(np.allclose(out, expected))
+        out_approx = nn.GELU(approx="precise")(mx.array(inputs))
+        self.assertTrue(np.allclose(out_approx, expected_approx))
 
         # Crudely check the approximations
         x = mx.arange(-6.0, 6.0, 12 / 100)
         y = nn.gelu(x)
         y_hat1 = nn.gelu_approx(x)
         y_hat2 = nn.gelu_fast_approx(x)
-        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
+        self.assertLess(mx.abs(y - y_hat1).max(), 0.0005)
         self.assertLess(mx.abs(y - y_hat2).max(), 0.025)
 
     def test_sin_pe(self):
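
For context, the relaxed error bound in the updated docstring and test can be checked independently of MLX. The following is a minimal sketch assuming NumPy and SciPy are available; `gelu_exact` and `gelu_tanh` are hypothetical helper names for this check, not MLX APIs:

```python
import math

import numpy as np
from scipy.special import erf  # exact GELU needs the Gauss error function


def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return x * 0.5 * (1.0 + erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # The tanh approximation this diff switches gelu_approx to.
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


x = np.arange(-6.0, 6.0, 12 / 100)  # same grid as the updated test
max_err = np.abs(gelu_exact(x) - gelu_tanh(x)).max()
assert max_err < 0.0005, max_err  # the bound asserted in test_nn.py
```

The constants :math:`\sqrt{2/\pi}` and :math:`0.044715` are those of the widely used tanh-based GELU approximation, which is also what `jax.nn.gelu(..., approximate=True)` computes; that is why the new `expected_approx` values in the test are taken from JAX.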