Commit d3a9005454 (parent 3f7aba8498) in ml-explore/mlx
(mirror of https://github.com/ml-explore/mlx.git)

Softshrink mapping + op (#552)

* Added Softshrink mapping + op
* formatting
* docs + nits in docstring

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -19,5 +19,6 @@ simple functions.
    prelu
    relu
    selu
+   softshrink
    silu
    step
@@ -33,5 +33,6 @@ Layers
    Sequential
    SiLU
    SinusoidalPositionalEncoding
+   Softshrink
    Step
    Transformer
@@ -18,6 +18,7 @@ from mlx.nn.layers.activations import (
     SiLU,
     Softmax,
     Softplus,
+    Softshrink,
     Softsign,
     Step,
     Tanh,
@@ -39,6 +40,7 @@ from mlx.nn.layers.activations import (
     silu,
     softmax,
     softplus,
+    softshrink,
     softsign,
     step,
     tanh,
@@ -89,6 +89,19 @@ def softsign(x):
     return mx.divide(x, 1 + mx.abs(x))
 
 
+def softshrink(x, lambd: float = 0.5):
+    r"""Applies the Softshrink activation function.
+
+    .. math::
+        \text{softshrink}(x) = \begin{cases}
+        x - \lambda & \text{if } x > \lambda \\
+        x + \lambda & \text{if } x < -\lambda \\
+        0 & \text{otherwise}
+        \end{cases}
+    """
+    return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)
+
+
 def celu(x, alpha=1.0):
     r"""Applies the Continuously Differentiable Exponential Linear Unit.
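For reference, a quick usage sketch of the new functional form (illustrative only, not part of the diff; it assumes an MLX build that includes this commit, with softshrink re-exported through mlx.nn as the import hunks above arrange):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([2.0, 0.3, -0.3, -2.0])
    # Default lambd=0.5: values in [-0.5, 0.5] collapse to 0,
    # everything else is pulled toward 0 by 0.5.
    print(nn.softshrink(x))             # ~[1.5, 0.0, 0.0, -1.5]
    # A larger lambd widens the dead zone and shrinks harder.
    print(nn.softshrink(x, lambd=1.0))  # ~[1.0, 0.0, 0.0, -1.0]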
@@ -173,7 +186,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
         \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
-        axis (int): The dimension to split along. Default: ``-1``.
+        axis (int): The dimension to split along. Default: ``-1``
     """
     a, b = mx.split(x, indices_or_sections=2, axis=axis)
     return a * mx.sigmoid(b)
@@ -189,7 +202,7 @@ class GLU(Module):
         \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
-        axis (int): The dimension to split along. Default: ``-1``.
+        axis (int): The dimension to split along. Default: ``-1``
     """
 
     def __init__(self, axis: int = -1):
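To make the documented split-and-gate behaviour concrete, a small sketch (illustrative only; it assumes glu is reachable from mlx.nn in this build):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.ones([4, 8])
    # Split x into halves a and b along the last axis, then gate: a * sigmoid(b).
    y = nn.glu(x, axis=-1)
    print(y.shape)  # the split axis is halved, 8 -> 4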
@@ -295,7 +308,7 @@ class ReLU(Module):
     r"""Applies the Rectified Linear Unit.
     Simply ``mx.maximum(x, 0)``.
 
-    See :func:`relu`, for the functional equivalent.
+    See :func:`relu` for the functional equivalent.
     """
@@ -305,7 +318,7 @@ class LeakyReLU(Module):
     Simply ``mx.maximum(negative_slope * x, x)``.
 
     Args:
-        negative_slope: Controls the angle of the negative slope. Default: 1e-2.
+        negative_slope: Controls the angle of the negative slope. Default: ``1e-2``
     """
 
     def __init__(self, negative_slope=1e-2):
@@ -320,10 +333,10 @@ class ELU(Module):
     r"""Applies the Exponential Linear Unit.
     Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
 
-    See :func:`elu`, for the functional equivalent.
+    See :func:`elu` for the functional equivalent.
 
     Args:
-        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+        alpha: the :math:`\alpha` value for the ELU formulation. Default: ``1.0``
     """
 
     def __init__(self, alpha=1.0):
@@ -338,7 +351,7 @@ class ELU(Module):
 class ReLU6(Module):
     r"""Applies the Rectified Linear Unit 6.
 
-    See :func:`relu6`, for the functional equivalent.
+    See :func:`relu6` for the functional equivalent.
     """
@@ -346,7 +359,7 @@ class ReLU6(Module):
 class Softmax(Module):
     r"""Applies the Softmax function.
 
-    See :func:`softmax`, for the functional equivalent.
+    See :func:`softmax` for the functional equivalent.
     """
@@ -354,7 +367,7 @@ class Softmax(Module):
 class Softplus(Module):
     r"""Applies the Softplus function.
 
-    See :func:`softplus`, for the functional equivalent.
+    See :func:`softplus` for the functional equivalent.
     """
@@ -362,19 +375,36 @@ class Softplus(Module):
 class Softsign(Module):
     r"""Applies the Softsign function.
 
-    See :func:`softsign`, for the functional equivalent.
+    See :func:`softsign` for the functional equivalent.
     """
 
 
+class Softshrink(Module):
+    r"""Applies the Softshrink function.
+
+    See :func:`softshrink` for the functional equivalent.
+
+    Args:
+        lambd: the :math:`\lambda` value for Softshrink. Default: ``0.5``
+    """
+
+    def __init__(self, lambd=0.5):
+        super().__init__()
+        self.lambd = lambd
+
+    def __call__(self, x):
+        return softshrink(x, self.lambd)
+
+
 class CELU(Module):
     r"""Applies the Continuously Differentiable Exponential Linear Unit.
     Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
     element wise.
 
-    See :func:`celu`, for the functional equivalent.
+    See :func:`celu` for the functional equivalent.
 
     Args:
-        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+        alpha: the :math:`\alpha` value for the CELU formulation. Default: ``1.0``
     """
 
     def __init__(self, alpha=1.0):
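The module form just stores lambd and defers to the functional softshrink on call, which is exactly what the new test below exercises. A minimal sketch (illustrative only, same environment assumptions as above; Sequential and Linear are existing layers in the package):

    import mlx.core as mx
    import mlx.nn as nn

    layer = nn.Softshrink(lambd=0.7)
    print(layer(mx.array([1.0, -1.0, 0.0])))  # ~[0.3, -0.3, 0.0]

    # Usable like any other activation module, e.g. inside a Sequential:
    model = nn.Sequential(nn.Linear(8, 16), nn.Softshrink(), nn.Linear(16, 1))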
@@ -389,7 +419,7 @@ class CELU(Module):
 class SiLU(Module):
     r"""Applies the Sigmoid Linear Unit. Also known as Swish.
 
-    See :func:`silu`, for the functional equivalent.
+    See :func:`silu` for the functional equivalent.
     """
@@ -397,7 +427,7 @@ class SiLU(Module):
 class LogSoftmax(Module):
     r"""Applies the Log Softmax function.
 
-    See :func:`log_softmax`, for the functional equivalent.
+    See :func:`log_softmax` for the functional equivalent.
     """
@@ -405,7 +435,7 @@ class LogSoftmax(Module):
 class LogSigmoid(Module):
     r"""Applies the Log Sigmoid function.
 
-    See :func:`log_sigmoid`, for the functional equivalent.
+    See :func:`log_sigmoid` for the functional equivalent.
     """
@@ -414,11 +444,11 @@ class PReLU(Module):
     Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
     is an array.
 
-    See :func:`prelu`, for the functional equivalent.
+    See :func:`prelu` for the functional equivalent.
 
     Args:
-        num_parameters: number of :math:`a` to learn. Default: 1
-        init: the initial value of :math:`a`. Default: 0.25
+        num_parameters: number of :math:`a` to learn. Default: ``1``
+        init: the initial value of :math:`a`. Default: ``0.25``
     """
 
     def __init__(self, num_parameters=1, init=0.25):
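A short sketch of how PReLU behaves at its initial slope, to go with the defaults documented above (illustrative only):

    import mlx.core as mx
    import mlx.nn as nn

    prelu = nn.PReLU()  # num_parameters=1, init=0.25
    # max(0, x) + a * min(0, x), with the learnable a starting at 0.25:
    print(prelu(mx.array([2.0, -2.0])))  # ~[2.0, -0.5]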
@@ -482,7 +512,7 @@ def tanh(x):
 class Tanh(Module):
     r"""Applies the hyperbolic tangent function.
 
-    See :func:`tanh`, for the functional equivalent.
+    See :func:`tanh` for the functional equivalent.
     """
@@ -490,7 +520,7 @@ class Tanh(Module):
 class Hardswish(Module):
     r"""Applies the hardswish function, element-wise.
 
-    See :func:`hardswish`, for the functional equivalent.
+    See :func:`hardswish` for the functional equivalent.
     """
@@ -522,5 +552,5 @@ class Step(Module):
 class SELU(Module):
     r"""Applies the Scaled Exponential Linear Unit.
 
-    See :func:`selu`, for the functional equivalent.
+    See :func:`selu` for the functional equivalent.
     """
@@ -751,6 +751,21 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softshrink(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softshrink(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.5, -0.5, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.Softshrink(lambd=0.7)(x)
+        expected_y = mx.array([0.3, -0.3, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_celu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.celu(x)