mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-07 17:44:38 +08:00
gelu tanh approx (#989)
* gelu tanh approx * gelu tanh approx * replace gelu approx with tanh approach * fix comments * fix comment
This commit is contained in:
@@ -717,16 +717,22 @@ class TestLayers(mlx_tests.MLXTestCase):
        expected = np.array(
            [1.0093501, -0.16925684, 0.22918941, 0.60498625, 0.49459383]
        )
        # From: jax.nn.gelu(np.array(inputs), approximate=True)
        expected_approx = np.array(
            [1.0091482, -0.1693441, 0.22918446, 0.60491, 0.4945476]
        )

        out = nn.GELU()(mx.array(inputs))
        self.assertTrue(np.allclose(out, expected))
        out_approx = nn.GELU(approx="precise")(mx.array(inputs))
        self.assertTrue(np.allclose(out_approx, expected_approx))

        # Crudely check the approximations
        x = mx.arange(-6.0, 6.0, 12 / 100)
        y = nn.gelu(x)
        y_hat1 = nn.gelu_approx(x)
        y_hat2 = nn.gelu_fast_approx(x)
        self.assertLess(mx.abs(y - y_hat1).max(), 0.0003)
        self.assertLess(mx.abs(y - y_hat1).max(), 0.0005)
        self.assertLess(mx.abs(y - y_hat2).max(), 0.025)

    def test_sin_pe(self):
Reference in New Issue
Block a user