From 60cd4a5a6ff42f72c182f513c622a1c862d9d353 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 16 Jun 2025 22:33:24 +0200
Subject: [PATCH] initial commit

---
 python/mlx/nn/layers/activations.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 8eafd75d3a..5bff0ad14b 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -132,6 +132,20 @@ def silu(x):
     return x * mx.sigmoid(x)
 
 
+@partial(mx.compile, shapeless=True)
+def clipped_silu(x, a_min=-100, a_max=100):
+    r"""Applies the Clipped Sigmoid Linear Unit.
+
+    Applies :math:`\text{clip}(x \sigma(x), a\_min, a\_max)` element wise, where
+    :math:`\sigma(\cdot)` is the logistic sigmoid.
+
+    Args:
+        a_min: minimum value for clipping. Default: ``-100``
+        a_max: maximum value for clipping. Default: ``100``
+    """
+    return mx.clip(x * mx.sigmoid(x), a_min=a_min, a_max=a_max)
+
+
 @partial(mx.compile, shapeless=True)
 def log_sigmoid(x):
     r"""Applies the Log Sigmoid function.
@@ -488,6 +502,25 @@ class SiLU(Module):
     """
 
 
+class ClippedSiLU(Module):
+    r"""Applies the Clipped Sigmoid Linear Unit.
+
+    See :func:`clipped_silu` for the functional equivalent.
+
+    Args:
+        a_min: minimum value for clipping. Default: ``-100``
+        a_max: maximum value for clipping. Default: ``100``
+    """
+
+    def __init__(self, a_min=-100, a_max=100):
+        super().__init__()
+        self.a_min = a_min
+        self.a_max = a_max
+
+    def __call__(self, x):
+        return clipped_silu(x, self.a_min, self.a_max)
+
+
 @_make_activation_module(log_softmax)
 class LogSoftmax(Module):
     r"""Applies the Log Softmax function.
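
Usage sketch (not part of the patch): a minimal example of how the new clipped_silu
function and ClippedSiLU module could be exercised once this diff is applied. The
names are imported directly from the module the diff touches, since the patch does
not register them anywhere else; the input values and the [-50, 50] clip bounds are
arbitrary choices for illustration.

# Illustrative example only; assumes the patch above has been applied.
import mlx.core as mx

# Pull the new names from the module modified by this diff.
from mlx.nn.layers.activations import ClippedSiLU, clipped_silu

x = mx.array([-200.0, -1.0, 0.0, 1.0, 200.0])

# Functional form: x * sigmoid(x), then clipped to [-50, 50].
y = clipped_silu(x, -50.0, 50.0)

# Module form with the same bounds, e.g. for use inside a model.
act = ClippedSiLU(a_min=-50.0, a_max=50.0)
z = act(x)

print(y)  # the +200 input saturates at 50 instead of ~200
print(z)  # matches y

The clipping only changes outputs outside [a_min, a_max]; for moderate inputs the
function behaves exactly like silu.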