mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Update: Fast GeLU Approximation (#744)
* add: fast gelu approx
* fix docs
* Update gelu_fast_approx function documentation
* Update python/mlx/nn/layers/activations.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix: test gelu

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
fe1dabf272
commit
de3d2467a3
@@ -185,11 +185,15 @@ def gelu_fast_approx(x):
|
||||
|
||||
.. math::
|
||||
|
||||
x = x \sigma\left(1.773 x\right)
|
||||
x = x \sigma\left(1.702 x\right)
|
||||
|
||||
where :math:`\sigma(\cdot)` is the logistic sigmoid.
|
||||
|
||||
References:
|
||||
- https://github.com/hendrycks/GELUs
|
||||
- https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return x * mx.sigmoid(1.773 * x)
|
||||
return x * mx.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
def glu(x: mx.array, axis: int = -1) -> mx.array:
|
||||
|
@@ -665,7 +665,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
y_hat1 = nn.gelu_approx(x)
|
||||
y_hat2 = nn.gelu_fast_approx(x)
|
||||
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
|
||||
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
|
||||
self.assertLess(mx.abs(y - y_hat2).max(), 0.025)
|
||||
|
||||
def test_sin_pe(self):
|
||||
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
|
||||
|
Loading…
Reference in New Issue
Block a user