feat: Added "tanh" option to GELU approximation (#1268)

This commit is contained in:
Atakan Tekparmak
2024-07-28 09:07:56 +02:00
committed by GitHub
parent 8cfb9fc0b8
commit 6e06e3a904
2 changed files with 13 additions and 3 deletions

View File

@@ -722,8 +722,13 @@ class TestLayers(mlx_tests.MLXTestCase):
out = nn.GELU()(mx.array(inputs))
self.assertTrue(np.allclose(out, expected))
# Test the precise/tanh approximation
out_approx = nn.GELU(approx="precise")(mx.array(inputs))
out_approx_tanh = nn.GELU(approx="tanh")(mx.array(inputs))
self.assertTrue(np.allclose(out_approx, expected_approx))
self.assertTrue(np.allclose(out_approx_tanh, expected_approx))
self.assertTrue(np.allclose(out_approx, out_approx_tanh))
# Crudely check the approximations
x = mx.arange(-6.0, 6.0, 12 / 100)