mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-02 13:54:44 +08:00
feat: Added "tanh" option to GELU approximation (#1268)
This commit is contained in:
@@ -722,8 +722,13 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
|
||||
out = nn.GELU()(mx.array(inputs))
|
||||
self.assertTrue(np.allclose(out, expected))
|
||||
|
||||
# Test the precise/tanh approximation
|
||||
out_approx = nn.GELU(approx="precise")(mx.array(inputs))
|
||||
out_approx_tanh = nn.GELU(approx="tanh")(mx.array(inputs))
|
||||
self.assertTrue(np.allclose(out_approx, expected_approx))
|
||||
self.assertTrue(np.allclose(out_approx_tanh, expected_approx))
|
||||
self.assertTrue(np.allclose(out_approx, out_approx_tanh))
|
||||
|
||||
# Crudely check the approximations
|
||||
x = mx.arange(-6.0, 6.0, 12 / 100)
|
||||
|
Reference in New Issue
Block a user