diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index cb1460ef3..2b5ded696 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -129,3 +129,16 @@ class GELU(Module):
 
     def __call__(self, x):
         return self._act(x)
+
+
+def tanh(x):
+    """Applies the hyperbolic tangent function.
+
+    Simply ``mx.tanh(x)``.
+    """
+    return mx.tanh(x)
+
+
+@_make_activation_module(tanh)
+class Tanh(Module):
+    pass