initial commit

This commit is contained in:
Goekdeniz-Guelmez 2025-06-16 22:33:24 +02:00
parent bc53f8293f
commit 60cd4a5a6f

View File

@ -132,6 +132,20 @@ def silu(x):
return x * mx.sigmoid(x) return x * mx.sigmoid(x)
@partial(mx.compile, shapeless=True)
def clipped_silu(x, a_min=-100, a_max=100):
    r"""Applies the Clipped Sigmoid Linear Unit.

    Applies :math:`\text{clip}(x \sigma(x), a\_min, a\_max)` element wise, where
    :math:`\sigma(\cdot)` is the logistic sigmoid.

    Args:
        a_min: minimum value for clipping. Default: ``-100``
        a_max: maximum value for clipping. Default: ``100``
    """
    # SiLU first, then clamp the result into [a_min, a_max].
    silu_out = x * mx.sigmoid(x)
    return mx.clip(silu_out, a_min=a_min, a_max=a_max)
@partial(mx.compile, shapeless=True) @partial(mx.compile, shapeless=True)
def log_sigmoid(x): def log_sigmoid(x):
r"""Applies the Log Sigmoid function. r"""Applies the Log Sigmoid function.
@ -488,6 +502,25 @@ class SiLU(Module):
""" """
class ClippedSiLU(Module):
    r"""Clipped Sigmoid Linear Unit activation module.

    See :func:`clipped_silu` for the functional equivalent.

    Args:
        a_min: minimum value for clipping. Default: ``-100``
        a_max: maximum value for clipping. Default: ``100``
    """

    def __init__(self, a_min=-100, a_max=100):
        super().__init__()
        # Stored so the clip bounds can be inspected/adjusted after construction.
        self.a_min, self.a_max = a_min, a_max

    def __call__(self, x):
        # Delegate to the compiled functional form.
        return clipped_silu(x, self.a_min, self.a_max)
@_make_activation_module(log_softmax) @_make_activation_module(log_softmax)
class LogSoftmax(Module): class LogSoftmax(Module):
r"""Applies the Log Softmax function. r"""Applies the Log Softmax function.