Update: Fast GeLU Approximation (#744)

* add: fast gelu approx

* fix docs

* Update gelu_fast_approx function documentation

* Update python/mlx/nn/layers/activations.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix: test gelu

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Noah Kasmanoff 2024-02-27 00:08:50 -05:00 committed by GitHub
parent fe1dabf272
commit de3d2467a3
2 changed files with 7 additions and 3 deletions

python/mlx/nn/layers/activations.py

@@ -185,11 +185,15 @@ def gelu_fast_approx(x):
     .. math::
 
-        x = x \sigma\left(1.773 x\right)
+        x = x \sigma\left(1.702 x\right)
 
     where :math:`\sigma(\cdot)` is the logistic sigmoid.
 
+    References:
+        - https://github.com/hendrycks/GELUs
+        - https://arxiv.org/abs/1606.08415
+
     """
-    return x * mx.sigmoid(1.773 * x)
+    return x * mx.sigmoid(1.702 * x)
 
 
 def glu(x: mx.array, axis: int = -1) -> mx.array:

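For context on the fix: GELU(x) = x Φ(x), where Φ is the standard normal CDF, and the referenced Hendrycks & Gimpel paper fits Φ with the logistic sigmoid, Φ(x) ≈ σ(1.702 x); the previous constant 1.773 does not match the paper. A minimal sketch that checks the fit, assuming only standard mlx.core ops (erf, sigmoid, linspace):

    import mlx.core as mx

    x = mx.linspace(-6, 6, num=1001)
    phi = (1 + mx.erf(x / 2**0.5)) / 2  # standard normal CDF via erf
    fit = mx.sigmoid(1.702 * x)         # the sigmoid fit used by gelu_fast_approx
    print(mx.abs(phi - fit).max().item())  # gap is about 0.01 on this range

Since the GELU-level error is this gap scaled by x, it peaks around 0.02, which is what the loosened test bound below accounts for.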
python/tests/test_nn.py

@@ -665,7 +665,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         y_hat1 = nn.gelu_approx(x)
         y_hat2 = nn.gelu_fast_approx(x)
         self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
-        self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
+        self.assertLess(mx.abs(y - y_hat2).max(), 0.025)
 
     def test_sin_pe(self):
         m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
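The tolerance is loosened from 0.02 to 0.025, presumably because the worst-case deviation of x σ(1.702 x) from the exact GELU sits right around 0.02. A quick standalone check of both approximations against an erf-based reference (a sketch; nn.gelu_approx and nn.gelu_fast_approx are the functions exercised by the test above):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.linspace(-6, 6, num=1001)
    y = x * (1 + mx.erf(x / 2**0.5)) / 2  # exact GELU: x * Phi(x)

    print(mx.abs(y - nn.gelu_approx(x)).max().item())       # tanh variant, ~3e-4
    print(mx.abs(y - nn.gelu_fast_approx(x)).max().item())  # sigmoid variant, ~0.02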