diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 178cbc7b1..fe9014adc 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -185,11 +185,15 @@ def gelu_fast_approx(x):
 
     .. math::
 
-        x = x \sigma\left(1.773 x\right)
+        x = x \sigma\left(1.702 x\right)
 
     where :math:`\sigma(\cdot)` is the logistic sigmoid.
+
+    References:
+        - https://github.com/hendrycks/GELUs
+        - https://arxiv.org/abs/1606.08415
     """
-    return x * mx.sigmoid(1.773 * x)
+    return x * mx.sigmoid(1.702 * x)
 
 
 def glu(x: mx.array, axis: int = -1) -> mx.array:
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 2c8346179..99154d3f6 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -665,7 +665,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         y_hat1 = nn.gelu_approx(x)
         y_hat2 = nn.gelu_fast_approx(x)
         self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
-        self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
+        self.assertLess(mx.abs(y - y_hat2).max(), 0.025)
 
     def test_sin_pe(self):
         m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
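
For context, a minimal sketch (not part of the patch, and assuming an mlx build with this change applied plus the public `nn.gelu` and `nn.gelu_fast_approx` functions) of why both edits go together: with the corrected 1.702 constant from Hendrycks & Gimpel, the fast approximation's maximum absolute error against the exact erf-based GELU is roughly 0.02 over a moderate input range, so the old 0.02 test bound is too tight and 0.025 is used instead.

```python
# Illustrative check only: compare the exact GELU with the 1.702 sigmoid
# approximation (Hendrycks & Gimpel, arXiv:1606.08415).
import mlx.core as mx
import mlx.nn as nn

x = mx.linspace(-6, 6, 1001)
exact = nn.gelu(x)              # erf-based GELU
fast = nn.gelu_fast_approx(x)   # x * sigmoid(1.702 * x) after this patch
max_err = mx.abs(exact - fast).max()
print(max_err)                  # roughly 0.02, hence the 0.025 test bound
```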