mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Update: Fast GeLU Approximation (#744)
* add: fast gelu approx
* fix docs
* Update gelu_fast_approx function documentation
* Update python/mlx/nn/layers/activations.py (Co-authored-by: Awni Hannun <awni.hannun@gmail.com>)
* fix: test gelu

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
fe1dabf272
commit
de3d2467a3
@ -185,11 +185,15 @@ def gelu_fast_approx(x):
|
|||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
|
|
||||||
x = x \sigma\left(1.773 x\right)
|
x = x \sigma\left(1.702 x\right)
|
||||||
|
|
||||||
where :math:`\sigma(\cdot)` is the logistic sigmoid.
|
where :math:`\sigma(\cdot)` is the logistic sigmoid.
|
||||||
|
|
||||||
|
References:
|
||||||
|
- https://github.com/hendrycks/GELUs
|
||||||
|
- https://arxiv.org/abs/1606.08415
|
||||||
"""
|
"""
|
||||||
return x * mx.sigmoid(1.773 * x)
|
return x * mx.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
|
||||||
def glu(x: mx.array, axis: int = -1) -> mx.array:
|
def glu(x: mx.array, axis: int = -1) -> mx.array:
|
||||||
|
@ -665,7 +665,7 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
y_hat1 = nn.gelu_approx(x)
|
y_hat1 = nn.gelu_approx(x)
|
||||||
y_hat2 = nn.gelu_fast_approx(x)
|
y_hat2 = nn.gelu_fast_approx(x)
|
||||||
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
|
self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
|
||||||
self.assertLess(mx.abs(y - y_hat2).max(), 0.02)
|
self.assertLess(mx.abs(y - y_hat2).max(), 0.025)
|
||||||
|
|
||||||
def test_sin_pe(self):
|
def test_sin_pe(self):
|
||||||
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
|
m = nn.SinusoidalPositionalEncoding(16, min_freq=0.01)
|
||||||
|
Loading…
Reference in New Issue
Block a user