mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
feat: add softsign, softmax, hardswish, logsoftmax activation function (#309)
* feat: add softsign activation function * run pre-commit * Add Softsign activation function * Add Softsign activation function * Add documentation for ReLU6, Softplus, and Softsign activations * Update activation functions in neural network layers * Add LogSoftmax and Hardswish activations * run pre-commit * Update activations.py * Added acknowledgements * Fix activation function comments * Fix activation functions in neural network layers
This commit is contained in:
parent
2aedf3e791
commit
5ad8fb7268
@ -6,7 +6,8 @@ with a short description of your contribution(s) below. For example:
|
||||
- Jane Smith: Added the `foo` and `bar` ops.
|
||||
|
||||
MLX was developed with contributions from the following individuals:
|
||||
|
||||
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added tri, tril, triu and safetensor support
|
||||
|
@ -5,14 +5,18 @@ from mlx.nn.layers.activations import (
|
||||
ELU,
|
||||
GELU,
|
||||
SELU,
|
||||
Hardswish,
|
||||
LeakyReLU,
|
||||
LogSigmoid,
|
||||
LogSoftmax,
|
||||
Mish,
|
||||
PReLU,
|
||||
ReLU,
|
||||
ReLU6,
|
||||
SiLU,
|
||||
Softmax,
|
||||
Softplus,
|
||||
Softsign,
|
||||
Step,
|
||||
Tanh,
|
||||
celu,
|
||||
@ -20,15 +24,19 @@ from mlx.nn.layers.activations import (
|
||||
gelu,
|
||||
gelu_approx,
|
||||
gelu_fast_approx,
|
||||
hardswish,
|
||||
leaky_relu,
|
||||
log_sigmoid,
|
||||
log_softmax,
|
||||
mish,
|
||||
prelu,
|
||||
relu,
|
||||
relu6,
|
||||
selu,
|
||||
silu,
|
||||
softmax,
|
||||
softplus,
|
||||
softsign,
|
||||
step,
|
||||
tanh,
|
||||
)
|
||||
|
@ -25,7 +25,7 @@ def sigmoid(x):
|
||||
|
||||
|
||||
def relu(x):
|
||||
"""Applies the Rectified Linear Unit.
|
||||
r"""Applies the Rectified Linear Unit.
|
||||
|
||||
Simply ``mx.maximum(x, 0)``.
|
||||
"""
|
||||
@ -33,15 +33,23 @@ def relu(x):
|
||||
|
||||
|
||||
def leaky_relu(x, negative_slope=0.01):
|
||||
"""Applies the Leaky Rectified Linear Unit.
|
||||
r"""Applies the Leaky Rectified Linear Unit.
|
||||
|
||||
Simply ``mx.maximum(negative_slope * x, x)``.
|
||||
"""
|
||||
return mx.maximum(negative_slope * x, x)
|
||||
|
||||
|
||||
def log_softmax(x, axis=-1):
|
||||
r"""Applies the Log Softmax function.
|
||||
|
||||
Applies :math:`x + \log \sum_i e^{x_i}` element wise.
|
||||
"""
|
||||
return x - mx.logsumexp(x, axis=axis, keepdims=True)
|
||||
|
||||
|
||||
def elu(x, alpha=1.0):
|
||||
"""Applies the Exponential Linear Unit.
|
||||
r"""Applies the Exponential Linear Unit.
|
||||
|
||||
Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
|
||||
"""
|
||||
@ -56,6 +64,14 @@ def relu6(x):
|
||||
return mx.minimum(mx.maximum(x, 0), 6.0)
|
||||
|
||||
|
||||
def softmax(x, axis=-1):
|
||||
r"""Applies the Softmax function.
|
||||
|
||||
Applies :math:`\frac{e^{x_i}}{\sum_j e^{x_j}}` element wise.
|
||||
"""
|
||||
return mx.softmax(x, axis=axis)
|
||||
|
||||
|
||||
def softplus(x):
|
||||
r"""Applies the Softplus function.
|
||||
|
||||
@ -64,6 +80,14 @@ def softplus(x):
|
||||
return mx.logaddexp(x, 0)
|
||||
|
||||
|
||||
def softsign(x):
|
||||
r"""Applies the Softsign function.
|
||||
|
||||
Applies :math:`\frac{x}{1 + |x|}` element wise.
|
||||
"""
|
||||
return mx.divide(x, 1 + mx.abs(x))
|
||||
|
||||
|
||||
def celu(x, alpha=1.0):
|
||||
r"""Applies the Continuously Differentiable Exponential Linear Unit.
|
||||
|
||||
@ -140,6 +164,11 @@ def gelu_fast_approx(x):
|
||||
|
||||
@_make_activation_module
|
||||
class Sigmoid(Module):
|
||||
r"""Applies the sigmoid function, element-wise.
|
||||
|
||||
.. math::
|
||||
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@ -202,13 +231,36 @@ def mish(x: mx.array) -> mx.array:
|
||||
return x * mx.tanh(softplus(x))
|
||||
|
||||
|
||||
def hardswish(x):
|
||||
r"""Applies the hardswish function, element-wise.
|
||||
|
||||
.. math::
|
||||
\text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6
|
||||
"""
|
||||
max_x_3 = mx.maximum(x + 3, 0)
|
||||
return x * mx.minimum(max_x_3, 6) / 6
|
||||
|
||||
|
||||
@_make_activation_module(mish)
|
||||
class Mish(Module):
|
||||
r"""Applies the Mish function, element-wise.
|
||||
|
||||
Reference: https://arxiv.org/abs/1908.08681
|
||||
|
||||
.. math::
|
||||
\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@_make_activation_module(relu)
|
||||
class ReLU(Module):
|
||||
r"""Applies the Rectified Linear Unit.
|
||||
Simply ``mx.maximum(x, 0)``.
|
||||
|
||||
See :func:`relu`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@ -249,11 +301,37 @@ class ELU(Module):
|
||||
|
||||
@_make_activation_module(relu6)
|
||||
class ReLU6(Module):
|
||||
r"""Applies the Rectified Linear Unit 6.
|
||||
|
||||
See :func:`relu6`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@_make_activation_module(softmax)
|
||||
class Softmax(Module):
|
||||
r"""Applies the Softmax function.
|
||||
|
||||
See :func:`softmax`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@_make_activation_module(softplus)
|
||||
class Softplus(Module):
|
||||
r"""Applies the Softplus function.
|
||||
|
||||
See :func:`softplus`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@_make_activation_module(softsign)
|
||||
class Softsign(Module):
|
||||
r"""Applies the Softsign function.
|
||||
|
||||
See :func:`softsign`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@ -278,15 +356,43 @@ class CELU(Module):
|
||||
|
||||
@_make_activation_module(silu)
|
||||
class SiLU(Module):
|
||||
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
|
||||
|
||||
See :func:`silu`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@_make_activation_module(log_softmax)
|
||||
class LogSoftmax(Module):
|
||||
r"""Applies the Log Softmax function.
|
||||
|
||||
See :func:`log_softmax`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@_make_activation_module(log_sigmoid)
|
||||
class LogSigmoid(Module):
|
||||
r"""Applies the Log Sigmoid function.
|
||||
|
||||
See :func:`log_sigmoid`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PReLU(Module):
|
||||
r"""Applies the element-wise parametric ReLU.
|
||||
Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
|
||||
is an array.
|
||||
|
||||
See :func:`prelu`, for the functional equivalent.
|
||||
|
||||
Args:
|
||||
num_parameters: number of :math:`a` to learn. Default: 1
|
||||
init: the initial value of :math:`a`. Default: 0.25
|
||||
"""
|
||||
|
||||
def __init__(self, num_parameters=1, init=0.25):
|
||||
super().__init__()
|
||||
self.weight = mx.full([num_parameters], init)
|
||||
@ -346,6 +452,19 @@ def tanh(x):
|
||||
|
||||
@_make_activation_module(tanh)
|
||||
class Tanh(Module):
|
||||
r"""Applies the hyperbolic tangent function.
|
||||
|
||||
See :func:`tanh`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@_make_activation_module(hardswish)
|
||||
class Hardswish(Module):
|
||||
r"""Applies the hardswish function, element-wise.
|
||||
|
||||
See :func:`hardswish`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@ -375,4 +494,8 @@ class Step(Module):
|
||||
|
||||
@_make_activation_module(selu)
|
||||
class SELU(Module):
|
||||
r"""Applies the Scaled Exponential Linear Unit.
|
||||
|
||||
See :func:`selu`, for the functional equivalent.
|
||||
"""
|
||||
pass
|
||||
|
@ -667,6 +667,15 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.shape, [5])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softmax(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.softmax(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.6652, 0.0900, 0.2447])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softplus(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.softplus(x)
|
||||
@ -676,6 +685,15 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softsign(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.softsign(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.5, -0.5, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_celu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.celu(x)
|
||||
@ -691,6 +709,15 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_log_softmax(self):
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
y = nn.log_softmax(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([-2.4076, -1.4076, -0.4076])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_log_sigmoid(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.log_sigmoid(x)
|
||||
@ -712,6 +739,15 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
|
||||
)
|
||||
|
||||
def test_hardswish(self):
|
||||
x = mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])
|
||||
y = nn.hardswish(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [5])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_rope(self):
|
||||
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
|
||||
rope = nn.RoPE(4, **kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user