Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)
Softshrink mapping + op (#552)
* Added Softshrink mapping + op
* formatting
* docs + nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
parent 3f7aba8498
commit d3a9005454
@@ -19,5 +19,6 @@ simple functions.
    prelu
    relu
    selu
+   softshrink
    silu
    step
@@ -33,5 +33,6 @@ Layers
    Sequential
    SiLU
    SinusoidalPositionalEncoding
+   Softshrink
    Step
    Transformer
@@ -18,6 +18,7 @@ from mlx.nn.layers.activations import (
     SiLU,
     Softmax,
     Softplus,
+    Softshrink,
     Softsign,
     Step,
     Tanh,
@@ -39,6 +40,7 @@ from mlx.nn.layers.activations import (
     silu,
     softmax,
     softplus,
+    softshrink,
     softsign,
     step,
     tanh,
@@ -89,6 +89,19 @@ def softsign(x):
     return mx.divide(x, 1 + mx.abs(x))
 
 
+def softshrink(x, lambd: float = 0.5):
+    r"""Applies the Softshrink activation function.
+
+    .. math::
+        \text{softshrink}(x) = \begin{cases}
+        x - \lambda & \text{if } x > \lambda \\
+        x + \lambda & \text{if } x < -\lambda \\
+        0 & \text{otherwise}
+        \end{cases}
+    """
+    return mx.where(mx.abs(x) > lambd, x - mx.sign(x) * lambd, 0)
+
+
 def celu(x, alpha=1.0):
     r"""Applies the Continuously Differentiable Exponential Linear Unit.
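For reference, here is a quick check of the new functional form. This is a minimal sketch and not part of the diff; it assumes an environment with mlx installed, and the input values are hypothetical.

# Sketch only (not part of the commit): exercise nn.softshrink at two lambd values.
import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, -0.2, 0.2, -1.0])
print(nn.softshrink(x))             # lambd defaults to 0.5 -> [0.5, 0.0, 0.0, -0.5]
print(nn.softshrink(x, lambd=0.1))  # -> [0.9, -0.1, 0.1, -0.9]

Values inside the [-lambd, lambd] band are zeroed; values outside are shrunk toward zero by lambd, matching the piecewise definition above.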
@@ -173,7 +186,7 @@ def glu(x: mx.array, axis: int = -1) -> mx.array:
         \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
-        axis (int): The dimension to split along. Default: ``-1``.
+        axis (int): The dimension to split along. Default: ``-1``
     """
     a, b = mx.split(x, indices_or_sections=2, axis=axis)
     return a * mx.sigmoid(b)
@@ -189,7 +202,7 @@ class GLU(Module):
         \textrm{GLU}(x) = a * \sigma(b)
 
     Args:
-        axis (int): The dimension to split along. Default: ``-1``.
+        axis (int): The dimension to split along. Default: ``-1``
     """
 
     def __init__(self, axis: int = -1):
@@ -295,7 +308,7 @@ class ReLU(Module):
     r"""Applies the Rectified Linear Unit.
     Simply ``mx.maximum(x, 0)``.
 
-    See :func:`relu`, for the functional equivalent.
+    See :func:`relu` for the functional equivalent.
     """
 
 
@@ -305,7 +318,7 @@ class LeakyReLU(Module):
     Simply ``mx.maximum(negative_slope * x, x)``.
 
     Args:
-        negative_slope: Controls the angle of the negative slope. Default: 1e-2.
+        negative_slope: Controls the angle of the negative slope. Default: ``1e-2``
     """
 
     def __init__(self, negative_slope=1e-2):
@@ -320,10 +333,10 @@ class ELU(Module):
     r"""Applies the Exponential Linear Unit.
     Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
 
-    See :func:`elu`, for the functional equivalent.
+    See :func:`elu` for the functional equivalent.
 
     Args:
-        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+        alpha: the :math:`\alpha` value for the ELU formulation. Default: ``1.0``
     """
 
     def __init__(self, alpha=1.0):
@@ -338,7 +351,7 @@ class ELU(Module):
 class ReLU6(Module):
     r"""Applies the Rectified Linear Unit 6.
 
-    See :func:`relu6`, for the functional equivalent.
+    See :func:`relu6` for the functional equivalent.
     """
 
 
@@ -346,7 +359,7 @@ class ReLU6(Module):
 class Softmax(Module):
     r"""Applies the Softmax function.
 
-    See :func:`softmax`, for the functional equivalent.
+    See :func:`softmax` for the functional equivalent.
     """
 
 
@@ -354,7 +367,7 @@ class Softmax(Module):
 class Softplus(Module):
     r"""Applies the Softplus function.
 
-    See :func:`softplus`, for the functional equivalent.
+    See :func:`softplus` for the functional equivalent.
     """
 
 
@@ -362,19 +375,36 @@ class Softplus(Module):
 class Softsign(Module):
     r"""Applies the Softsign function.
 
-    See :func:`softsign`, for the functional equivalent.
+    See :func:`softsign` for the functional equivalent.
     """
 
 
+class Softshrink(Module):
+    r"""Applies the Softshrink function.
+
+    See :func:`softshrink` for the functional equivalent.
+
+    Args:
+        lambd: the :math:`\lambda` value for Softshrink. Default: ``0.5``
+    """
+
+    def __init__(self, lambd=0.5):
+        super().__init__()
+        self.lambd = lambd
+
+    def __call__(self, x):
+        return softshrink(x, self.lambd)
+
+
 class CELU(Module):
     r"""Applies the Continuously Differentiable Exponential Linear Unit.
     Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
     element wise.
 
-    See :func:`celu`, for the functional equivalent.
+    See :func:`celu` for the functional equivalent.
 
     Args:
-        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+        alpha: the :math:`\alpha` value for the CELU formulation. Default: ``1.0``
     """
 
     def __init__(self, alpha=1.0):
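The layer form simply stores lambd and delegates to the functional op. A minimal usage sketch, not part of the diff, mirroring how the new unit test further below exercises it:

# Sketch only (not part of the commit): nn.Softshrink holds lambd as state.
import mlx.core as mx
import mlx.nn as nn

layer = nn.Softshrink(lambd=0.7)
x = mx.array([1.0, -1.0, 0.0])
print(layer(x))  # -> [0.3, -0.3, 0.0], same as nn.softshrink(x, lambd=0.7)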
@@ -389,7 +419,7 @@ class CELU(Module):
 class SiLU(Module):
     r"""Applies the Sigmoid Linear Unit. Also known as Swish.
 
-    See :func:`silu`, for the functional equivalent.
+    See :func:`silu` for the functional equivalent.
     """
 
 
@@ -397,7 +427,7 @@ class SiLU(Module):
 class LogSoftmax(Module):
     r"""Applies the Log Softmax function.
 
-    See :func:`log_softmax`, for the functional equivalent.
+    See :func:`log_softmax` for the functional equivalent.
     """
 
 
@@ -405,7 +435,7 @@ class LogSoftmax(Module):
 class LogSigmoid(Module):
     r"""Applies the Log Sigmoid function.
 
-    See :func:`log_sigmoid`, for the functional equivalent.
+    See :func:`log_sigmoid` for the functional equivalent.
     """
 
 
@@ -414,11 +444,11 @@ class PReLU(Module):
     Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
     is an array.
 
-    See :func:`prelu`, for the functional equivalent.
+    See :func:`prelu` for the functional equivalent.
 
     Args:
-        num_parameters: number of :math:`a` to learn. Default: 1
-        init: the initial value of :math:`a`. Default: 0.25
+        num_parameters: number of :math:`a` to learn. Default: ``1``
+        init: the initial value of :math:`a`. Default: ``0.25``
     """
 
     def __init__(self, num_parameters=1, init=0.25):
@@ -482,7 +512,7 @@ def tanh(x):
 class Tanh(Module):
     r"""Applies the hyperbolic tangent function.
 
-    See :func:`tanh`, for the functional equivalent.
+    See :func:`tanh` for the functional equivalent.
     """
 
 
@@ -490,7 +520,7 @@ class Tanh(Module):
 class Hardswish(Module):
     r"""Applies the hardswish function, element-wise.
 
-    See :func:`hardswish`, for the functional equivalent.
+    See :func:`hardswish` for the functional equivalent.
     """
 
 
@@ -522,5 +552,5 @@ class Step(Module):
 class SELU(Module):
     r"""Applies the Scaled Exponential Linear Unit.
 
-    See :func:`selu`, for the functional equivalent.
+    See :func:`selu` for the functional equivalent.
     """
@@ -751,6 +751,21 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softshrink(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softshrink(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.5, -0.5, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.Softshrink(lambd=0.7)(x)
+        expected_y = mx.array([0.3, -0.3, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_celu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.celu(x)
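Beyond the unit test above, one way to sanity-check the op is against a NumPy reference of the same piecewise definition. This is a hypothetical helper under the assumption that numpy is available; it is not part of the commit's test suite.

# Hypothetical cross-check, not part of the commit.
import numpy as np
import mlx.core as mx
import mlx.nn as nn

def softshrink_ref(x, lambd=0.5):
    # Same piecewise rule: x - lambd above lambd, x + lambd below -lambd, else 0.
    return np.where(np.abs(x) > lambd, x - np.sign(x) * lambd, 0.0)

x_np = np.linspace(-2.0, 2.0, 9, dtype=np.float32)
y_mlx = np.array(nn.softshrink(mx.array(x_np), lambd=0.5))
assert np.allclose(y_mlx, softshrink_ref(x_np, 0.5), atol=1e-6)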