diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index ff03a29ed..8eafd75d3 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -536,12 +536,17 @@ class GELU(Module):
 
     .. math::
         \textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
-        \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
+        \textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)
 
     respectively.
 
+    .. note::
+        For compatibility with the PyTorch API, 'tanh' can be used as an alias
+        for 'precise'.
+
     See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
     functional equivalents and information regarding error bounds.
+
     Args:
         approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
 
@@ -552,13 +557,13 @@ class GELU(Module):
 
         if approx == "none":
             self._act = gelu
-        elif approx == "precise":
+        elif approx == "precise" or approx == "tanh":
             self._act = gelu_approx
         elif approx == "fast":
             self._act = gelu_fast_approx
         else:
             raise ValueError(
-                f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
+                f"The approximation should be in ['none', 'precise', 'tanh', 'fast'] but '{approx}' was given"
             )
 
     def __call__(self, x):
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 70b2eecc4..f3d47cf26 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -722,8 +722,13 @@ class TestLayers(mlx_tests.MLXTestCase):
 
         out = nn.GELU()(mx.array(inputs))
         self.assertTrue(np.allclose(out, expected))
+
+        # Test the precise/tanh approximation
         out_approx = nn.GELU(approx="precise")(mx.array(inputs))
+        out_approx_tanh = nn.GELU(approx="tanh")(mx.array(inputs))
         self.assertTrue(np.allclose(out_approx, expected_approx))
+        self.assertTrue(np.allclose(out_approx_tanh, expected_approx))
+        self.assertTrue(np.allclose(out_approx, out_approx_tanh))
 
         # Crudely check the approximations
         x = mx.arange(-6.0, 6.0, 12 / 100)
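
For reference, a minimal usage sketch of the alias added above. This snippet is not part of the patch; it assumes an environment with mlx installed and uses only the nn.GELU API and error message shown in the diff (mx.allclose stands in for the numpy comparison used in the tests):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.arange(-3.0, 3.0, 0.5)

    # 'tanh' is accepted as an alias for 'precise' (PyTorch-style naming),
    # so both layers dispatch to the same gelu_approx function.
    y_precise = nn.GELU(approx="precise")(x)
    y_tanh = nn.GELU(approx="tanh")(x)
    assert mx.allclose(y_precise, y_tanh)

    # Any other value still raises, with 'tanh' now listed among the options:
    # nn.GELU(approx="exact")
    # ValueError: The approximation should be in ['none', 'precise', 'tanh', 'fast'] ...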