Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-08 18:16:41 +08:00).
initial commit
This commit is contained in:
parent
bc53f8293f
commit
60cd4a5a6f
@@ -132,6 +132,20 @@ def silu(x):
|
||||
return x * mx.sigmoid(x)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
def clipped_silu(x, a_min=-100, a_max=100):
    r"""Applies the Clipped Sigmoid Linear Unit.

    Applies :math:`\text{clip}(x \sigma(x), a\_min, a\_max)` element wise, where
    :math:`\sigma(\cdot)` is the logistic sigmoid.

    Args:
        x (array): The input array.
        a_min: minimum value for clipping. Default: ``-100``
        a_max: maximum value for clipping. Default: ``100``

    Returns:
        The element-wise clipped silu of ``x``.
    """
    # silu(x) = x * sigmoid(x); the clip bounds the activation so that
    # large-magnitude inputs cannot produce extreme outputs.
    return mx.clip(x * mx.sigmoid(x), a_min=a_min, a_max=a_max)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def log_sigmoid(x):
|
||||
r"""Applies the Log Sigmoid function.
|
||||
@@ -488,6 +502,25 @@ class SiLU(Module):
|
||||
"""
|
||||
|
||||
|
||||
class ClippedSiLU(Module):
    r"""Sigmoid Linear Unit with clipped output.

    Computes :math:`\text{clip}(x \sigma(x), a\_min, a\_max)` element wise;
    see :func:`clipped_silu` for the functional equivalent.

    Args:
        a_min: minimum value for clipping. Default: ``-100``
        a_max: maximum value for clipping. Default: ``100``
    """

    def __init__(self, a_min=-100, a_max=100):
        super().__init__()
        # Lower/upper clipping bounds forwarded to clipped_silu on each call.
        self.a_min = a_min
        self.a_max = a_max

    def __call__(self, x):
        # Delegate to the compiled functional form with the stored bounds.
        return clipped_silu(x, a_min=self.a_min, a_max=self.a_max)
|
||||
|
||||
|
||||
@_make_activation_module(log_softmax)
|
||||
class LogSoftmax(Module):
|
||||
r"""Applies the Log Softmax function.
|
||||
|
Loading…
Reference in New Issue
Block a user