Softshrink mapping + op (#552)

* Added Softshrink mapping + op

* formatting

* docs + nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Andre Slavescu, 2024-01-30 15:56:28 -05:00 (committed by GitHub)
parent 3f7aba8498
commit d3a9005454
5 changed files with 70 additions and 21 deletions


@@ -19,5 +19,6 @@ simple functions.
prelu
relu
selu
softshrink
silu
step


@@ -33,5 +33,6 @@ Layers
Sequential
SiLU
SinusoidalPositionalEncoding
Softshrink
Step
Transformer


@@ -18,6 +18,7 @@ from mlx.nn.layers.activations import (
SiLU,
Softmax,
Softplus,
Softshrink,
Softsign,
Step,
Tanh,
@@ -39,6 +40,7 @@ from mlx.nn.layers.activations import (
silu,
softmax,
softplus,
softshrink,
softsign,
step,
tanh,


@@ -89,6 +89,19 @@ def softsign(x):
return mx.divide(x, 1 + mx.abs(x))
def softshrink(x, lambd: float = 0.5):
r"""Applies the Softshrink activation function.
.. math::
\text{softshrink}(x) = \begin{cases}
x - \lambda & \text{if } x > \lambda \\
x + \lambda & \text{if } x < -\lambda \\
0 & \text{otherwise}
\end{cases}
"""
return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)
def celu(x, alpha=1.0):
r"""Applies the Continuously Differentiable Exponential Linear Unit.
@@ -173,7 +186,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``.
axis (int): The dimension to split along. Default: ``-1``
"""
a, b = mx.split(x, indices_or_sections=2, axis=axis)
return a * mx.sigmoid(b)
@@ -189,7 +202,7 @@ class GLU(Module):
textrm{GLU}(x) = a * \sigma(b)
Args:
axis (int): The dimension to split along. Default: ``-1``.
axis (int): The dimension to split along. Default: ``-1``
"""
def __init__(self, axis: int = -1):
@@ -295,7 +308,7 @@ class ReLU(Module):
r"""Applies the Rectified Linear Unit.
Simply ``mx.maximum(x, 0)``.
See :func:`relu`, for the functional equivalent.
See :func:`relu` for the functional equivalent.
"""
@@ -305,7 +318,7 @@ class LeakyReLU(Module):
Simply ``mx.maximum(negative_slope * x, x)``.
Args:
negative_slope: Controls the angle of the negative slope. Default: 1e-2.
negative_slope: Controls the angle of the negative slope. Default: ``1e-2``
"""
def __init__(self, negative_slope=1e-2):
@@ -320,10 +333,10 @@ class ELU(Module):
r"""Applies the Exponential Linear Unit.
Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
See :func:`elu`, for the functional equivalent.
See :func:`elu` for the functional equivalent.
Args:
alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
alpha: the :math:`\alpha` value for the ELU formulation. Default: ``1.0``
"""
def __init__(self, alpha=1.0):
@@ -338,7 +351,7 @@ class ELU(Module):
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6`, for the functional equivalent.
See :func:`relu6` for the functional equivalent.
"""
@@ -346,7 +359,7 @@ class ReLU6(Module):
class Softmax(Module):
r"""Applies the Softmax function.
See :func:`softmax`, for the functional equivalent.
See :func:`softmax` for the functional equivalent.
"""
@@ -354,7 +367,7 @@ class Softmax(Module):
class Softplus(Module):
r"""Applies the Softplus function.
See :func:`softplus`, for the functional equivalent.
See :func:`softplus` for the functional equivalent.
"""
@@ -362,19 +375,36 @@ class Softplus(Module):
class Softsign(Module):
r"""Applies the Softsign function.
See :func:`softsign`, for the functional equivalent.
See :func:`softsign` for the functional equivalent.
"""
class Softshrink(Module):
r"""Applies the Softshrink function.
See :func:`softshrink` for the functional equivalent.
Args:
lambd: the :math:`\lambda` value for Softshrink. Default: ``0.5``
"""
def __init__(self, lambd=0.5):
super().__init__()
self.lambd = lambd
def __call__(self, x):
return softshrink(x, self.lambd)
class CELU(Module):
r"""Applies the Continuously Differentiable Exponential Linear Unit.
Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
element wise.
See :func:`celu`, for the functional equivalent.
See :func:`celu` for the functional equivalent.
Args:
alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
alpha: the :math:`\alpha` value for the CELU formulation. Default: ``1.0``
"""
def __init__(self, alpha=1.0):
@@ -389,7 +419,7 @@ class CELU(Module):
class SiLU(Module):
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
See :func:`silu`, for the functional equivalent.
See :func:`silu` for the functional equivalent.
"""
@@ -397,7 +427,7 @@ class SiLU(Module):
class LogSoftmax(Module):
r"""Applies the Log Softmax function.
See :func:`log_softmax`, for the functional equivalent.
See :func:`log_softmax` for the functional equivalent.
"""
@@ -405,7 +435,7 @@ class LogSoftmax(Module):
class LogSigmoid(Module):
r"""Applies the Log Sigmoid function.
See :func:`log_sigmoid`, for the functional equivalent.
See :func:`log_sigmoid` for the functional equivalent.
"""
@@ -414,11 +444,11 @@ class PReLU(Module):
Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
is an array.
See :func:`prelu`, for the functional equivalent.
See :func:`prelu` for the functional equivalent.
Args:
num_parameters: number of :math:`a` to learn. Default: 1
init: the initial value of :math:`a`. Default: 0.25
num_parameters: number of :math:`a` to learn. Default: ``1``
init: the initial value of :math:`a`. Default: ``0.25``
"""
def __init__(self, num_parameters=1, init=0.25):
@@ -482,7 +512,7 @@ def tanh(x):
class Tanh(Module):
r"""Applies the hyperbolic tangent function.
See :func:`tanh`, for the functional equivalent.
See :func:`tanh` for the functional equivalent.
"""
@@ -490,7 +520,7 @@ class Tanh(Module):
class Hardswish(Module):
r"""Applies the hardswish function, element-wise.
See :func:`hardswish`, for the functional equivalent.
See :func:`hardswish` for the functional equivalent.
"""
@@ -522,5 +552,5 @@ class Step(Module):
class SELU(Module):
r"""Applies the Scaled Exponential Linear Unit.
See :func:`selu`, for the functional equivalent.
See :func:`selu` for the functional equivalent.
"""


@@ -751,6 +751,21 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_softshrink(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.softshrink(x)
epsilon = 1e-4
expected_y = mx.array([0.5, -0.5, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
y = nn.Softshrink(lambd=0.7)(x)
expected_y = mx.array([0.3, -0.3, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_celu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.celu(x)