diff --git a/docs/src/python/nn/functions.rst b/docs/src/python/nn/functions.rst
index fd4302ef1..fc99dcad1 100644
--- a/docs/src/python/nn/functions.rst
+++ b/docs/src/python/nn/functions.rst
@@ -19,5 +19,6 @@ simple functions.
    prelu
    relu
    selu
+   softshrink
    silu
    step
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index fc8848c54..ef099ef2f 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -33,5 +33,6 @@ Layers
    Sequential
    SiLU
    SinusoidalPositionalEncoding
+   Softshrink
    Step
    Transformer
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 00832f525..f5092418b 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -18,6 +18,7 @@ from mlx.nn.layers.activations import (
     SiLU,
     Softmax,
     Softplus,
+    Softshrink,
     Softsign,
     Step,
     Tanh,
@@ -39,6 +40,7 @@ from mlx.nn.layers.activations import (
     silu,
     softmax,
     softplus,
+    softshrink,
     softsign,
     step,
     tanh,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 87d2fbac6..db07ce190 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -89,6 +89,19 @@ def softsign(x):
     return mx.divide(x, 1 + mx.abs(x))
 
 
+def softshrink(x, lambd: float = 0.5):
+    r"""Applies the Softshrink activation function.
+
+    .. math::
+        \text{softshrink}(x) = \begin{cases}
+        x - \lambda & \text{if } x > \lambda \\
+        x + \lambda & \text{if } x < -\lambda \\
+        0 & \text{otherwise}
+        \end{cases}
+    """
+    return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)
+
+
 def celu(x, alpha=1.0):
     r"""Applies the Continuously Differentiable Exponential Linear Unit.
 
@@ -173,7 +186,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
         \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
-        axis (int): The dimension to split along. Default: ``-1``.
+        axis (int): The dimension to split along. Default: ``-1``
     """
     a, b = mx.split(x, indices_or_sections=2, axis=axis)
     return a * mx.sigmoid(b)
@@ -189,7 +202,7 @@ class GLU(Module):
         \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
-        axis (int): The dimension to split along. Default: ``-1``.
+        axis (int): The dimension to split along. Default: ``-1``
     """
 
     def __init__(self, axis: int = -1):
@@ -295,7 +308,7 @@ class ReLU(Module):
     r"""Applies the Rectified Linear Unit.
 
     Simply ``mx.maximum(x, 0)``.
 
-    See :func:`relu`, for the functional equivalent.
+    See :func:`relu` for the functional equivalent.
     """
 
@@ -305,7 +318,7 @@ class LeakyReLU(Module):
     Simply ``mx.maximum(negative_slope * x, x)``.
 
     Args:
-        negative_slope: Controls the angle of the negative slope. Default: 1e-2.
+        negative_slope: Controls the angle of the negative slope. Default: ``1e-2``
     """
 
     def __init__(self, negative_slope=1e-2):
@@ -320,10 +333,10 @@ class ELU(Module):
     r"""Applies the Exponential Linear Unit.
     Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
 
-    See :func:`elu`, for the functional equivalent.
+    See :func:`elu` for the functional equivalent.
 
     Args:
-        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+        alpha: the :math:`\alpha` value for the ELU formulation. Default: ``1.0``
     """
 
     def __init__(self, alpha=1.0):
@@ -338,7 +351,7 @@ class ELU(Module):
 class ReLU6(Module):
     r"""Applies the Rectified Linear Unit 6.
 
-    See :func:`relu6`, for the functional equivalent.
+    See :func:`relu6` for the functional equivalent.
     """
 
 
@@ -346,7 +359,7 @@ class Softmax(Module):
     r"""Applies the Softmax function.
 
-    See :func:`softmax`, for the functional equivalent.
+    See :func:`softmax` for the functional equivalent.
     """
 
 
@@ -354,7 +367,7 @@ class Softmax(Module):
 class Softplus(Module):
     r"""Applies the Softplus function.
 
-    See :func:`softplus`, for the functional equivalent.
+    See :func:`softplus` for the functional equivalent.
     """
 
 
@@ -362,19 +375,36 @@ class Softplus(Module):
 class Softsign(Module):
     r"""Applies the Softsign function.
 
-    See :func:`softsign`, for the functional equivalent.
+    See :func:`softsign` for the functional equivalent.
     """
 
 
+class Softshrink(Module):
+    r"""Applies the Softshrink function.
+
+    See :func:`softshrink` for the functional equivalent.
+
+    Args:
+        lambd: the :math:`\lambda` value for Softshrink. Default: ``0.5``
+    """
+
+    def __init__(self, lambd=0.5):
+        super().__init__()
+        self.lambd = lambd
+
+    def __call__(self, x):
+        return softshrink(x, self.lambd)
+
+
 class CELU(Module):
     r"""Applies the Continuously Differentiable Exponential Linear Unit.
     Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
     element wise.
 
-    See :func:`celu`, for the functional equivalent.
+    See :func:`celu` for the functional equivalent.
 
     Args:
-        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+        alpha: the :math:`\alpha` value for the CELU formulation. Default: ``1.0``
     """
 
     def __init__(self, alpha=1.0):
@@ -389,7 +419,7 @@ class CELU(Module):
 class SiLU(Module):
     r"""Applies the Sigmoid Linear Unit. Also known as Swish.
 
-    See :func:`silu`, for the functional equivalent.
+    See :func:`silu` for the functional equivalent.
     """
 
 
@@ -397,7 +427,7 @@ class SiLU(Module):
 class LogSoftmax(Module):
     r"""Applies the Log Softmax function.
 
-    See :func:`log_softmax`, for the functional equivalent.
+    See :func:`log_softmax` for the functional equivalent.
     """
 
 
@@ -405,7 +435,7 @@ class LogSoftmax(Module):
 class LogSigmoid(Module):
     r"""Applies the Log Sigmoid function.
 
-    See :func:`log_sigmoid`, for the functional equivalent.
+    See :func:`log_sigmoid` for the functional equivalent.
     """
 
 
@@ -414,11 +444,11 @@ class PReLU(Module):
     Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
     is an array.
 
-    See :func:`prelu`, for the functional equivalent.
+    See :func:`prelu` for the functional equivalent.
 
     Args:
-        num_parameters: number of :math:`a` to learn. Default: 1
-        init: the initial value of :math:`a`. Default: 0.25
+        num_parameters: number of :math:`a` to learn. Default: ``1``
+        init: the initial value of :math:`a`. Default: ``0.25``
     """
 
     def __init__(self, num_parameters=1, init=0.25):
@@ -482,7 +512,7 @@ def tanh(x):
 class Tanh(Module):
     r"""Applies the hyperbolic tangent function.
 
-    See :func:`tanh`, for the functional equivalent.
+    See :func:`tanh` for the functional equivalent.
     """
 
 
@@ -490,7 +520,7 @@ class Tanh(Module):
 class Hardswish(Module):
     r"""Applies the hardswish function, element-wise.
 
-    See :func:`hardswish`, for the functional equivalent.
+    See :func:`hardswish` for the functional equivalent.
     """
 
 
@@ -522,5 +552,5 @@ class Step(Module):
 class SELU(Module):
     r"""Applies the Scaled Exponential Linear Unit.
 
-    See :func:`selu`, for the functional equivalent.
+    See :func:`selu` for the functional equivalent.
     """
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 2893af8e7..6d65e19bd 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -751,6 +751,21 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softshrink(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softshrink(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.5, -0.5, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.Softshrink(lambd=0.7)(x)
+        expected_y = mx.array([0.3, -0.3, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_celu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.celu(x)