diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 8eafd75d3a..5bff0ad14b 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -132,6 +132,20 @@ def silu(x):
     return x * mx.sigmoid(x)
 
 
+@partial(mx.compile, shapeless=True)
+def clipped_silu(x, a_min=-100, a_max=100):
+    r"""Applies the Clipped Sigmoid Linear Unit.
+
+    Applies :math:`\text{clip}(x \sigma(x), a\_min, a\_max)` element wise, where
+    :math:`\sigma(\cdot)` is the logistic sigmoid.
+
+    Args:
+        a_min: minimum value for clipping. Default: ``-100``
+        a_max: maximum value for clipping. Default: ``100``
+    """
+    return mx.clip(x * mx.sigmoid(x), a_min=a_min, a_max=a_max)
+
+
 @partial(mx.compile, shapeless=True)
 def log_sigmoid(x):
     r"""Applies the Log Sigmoid function.
@@ -488,6 +502,25 @@ class SiLU(Module):
     """
 
 
+class ClippedSiLU(Module):
+    r"""Applies the Clipped Sigmoid Linear Unit.
+
+    See :func:`clipped_silu` for the functional equivalent.
+
+    Args:
+        a_min: minimum value for clipping. Default: ``-100``
+        a_max: maximum value for clipping. Default: ``100``
+    """
+
+    def __init__(self, a_min=-100, a_max=100):
+        super().__init__()
+        self.a_min = a_min
+        self.a_max = a_max
+
+    def __call__(self, x):
+        return clipped_silu(x, self.a_min, self.a_max)
+
+
 @_make_activation_module(log_softmax)
 class LogSoftmax(Module):
     r"""Applies the Log Softmax function.
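
Usage sketch for the new activation (a hypothetical example; it assumes clipped_silu and
ClippedSiLU are re-exported under mlx.nn the same way the existing activations are, which
is handled outside the hunks shown above):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-200.0, -1.0, 0.0, 1.0, 200.0])

    # Functional form: SiLU output clipped to [-100, 100] by default.
    y = nn.clipped_silu(x)

    # Module form with custom clipping bounds.
    act = nn.ClippedSiLU(a_min=-50, a_max=50)
    z = act(x)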