mirror of https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
feat: add softsign, softmax, hardswish, logsoftmax activation function (#309)
* feat: add softsign activation function
* run pre-commit
* Add Softsign activation function
* Add Softsign activation function
* Add documentation for ReLU6, Softplus, and Softsign activations
* Update activation functions in neural network layers
* Add LogSoftmax and Hardswish activations
* run pre-commit
* Update activations.py
* Added acknowledgements
* Fix activation function comments
* Fix activation functions in neural network layers
This commit is contained in:
parent 2aedf3e791
commit 5ad8fb7268
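Taken together, the diff below exposes the four new activations both as functions (`nn.softmax`, `nn.softsign`, `nn.hardswish`, `nn.log_softmax`) and as module classes (`Softmax`, `Softsign`, `Hardswish`, `LogSoftmax`) that follow the existing activation-module pattern. A minimal usage sketch, assuming a build of mlx that includes this commit:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, -1.0, 0.0])

    # Functional forms added by this commit.
    print(nn.softmax(x))      # probabilities that sum to 1
    print(nn.log_softmax(x))  # log of the softmax, computed via logsumexp
    print(nn.softsign(x))     # x / (1 + |x|), bounded in (-1, 1)
    print(nn.hardswish(x))    # x * min(max(x + 3, 0), 6) / 6

    # Module (layer) forms of the same activations, usable like nn.ReLU().
    act = nn.Softsign()
    print(act(x))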
@@ -7,6 +7,7 @@ with a short description of your contribution(s) below. For example:
 
 MLX was developed with contributions from the following individuals:
 
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added tri, tril, triu and safetensor support
@@ -5,14 +5,18 @@ from mlx.nn.layers.activations import (
     ELU,
     GELU,
     SELU,
+    Hardswish,
     LeakyReLU,
     LogSigmoid,
+    LogSoftmax,
     Mish,
     PReLU,
     ReLU,
     ReLU6,
     SiLU,
+    Softmax,
     Softplus,
+    Softsign,
     Step,
     Tanh,
     celu,
@@ -20,15 +24,19 @@ from mlx.nn.layers.activations import (
     gelu,
     gelu_approx,
     gelu_fast_approx,
+    hardswish,
     leaky_relu,
     log_sigmoid,
+    log_softmax,
     mish,
     prelu,
     relu,
     relu6,
     selu,
     silu,
+    softmax,
     softplus,
+    softsign,
     step,
     tanh,
 )
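With the re-exports above, both spellings resolve from the `mlx.nn` namespace, so the new activations can be dropped straight into a model definition. A small sketch, assuming the existing `nn.Linear` and `nn.Sequential` layers; the layer sizes are made up for illustration:

    import mlx.core as mx
    import mlx.nn as nn

    # Hypothetical two-layer MLP using the newly exported Hardswish module.
    model = nn.Sequential(
        nn.Linear(8, 16),
        nn.Hardswish(),
        nn.Linear(16, 4),
    )
    logits = model(mx.zeros((2, 8)))
    probs = nn.softmax(logits, axis=-1)  # functional form, same namespace
    print(probs.shape)  # batch of 2, 4 classes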
@@ -25,7 +25,7 @@ def sigmoid(x):
 
 
 def relu(x):
-    """Applies the Rectified Linear Unit.
+    r"""Applies the Rectified Linear Unit.
 
     Simply ``mx.maximum(x, 0)``.
     """
@@ -33,15 +33,23 @@ def relu(x):
 
 
 def leaky_relu(x, negative_slope=0.01):
-    """Applies the Leaky Rectified Linear Unit.
+    r"""Applies the Leaky Rectified Linear Unit.
 
     Simply ``mx.maximum(negative_slope * x, x)``.
     """
     return mx.maximum(negative_slope * x, x)
 
 
+def log_softmax(x, axis=-1):
+    r"""Applies the Log Softmax function.
+
+    Applies :math:`x_i - \log \sum_j e^{x_j}` element wise.
+    """
+    return x - mx.logsumexp(x, axis=axis, keepdims=True)
+
+
 def elu(x, alpha=1.0):
-    """Applies the Exponential Linear Unit.
+    r"""Applies the Exponential Linear Unit.
 
     Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
     """
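The new `log_softmax` computes `x - logsumexp(x)` rather than composing `log(softmax(x))`; the two agree mathematically, but the logsumexp form stays finite when some softmax outputs underflow to zero. A quick check of that equivalence, assuming this commit is installed:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, 2.0, 3.0])
    a = nn.log_softmax(x)      # x - logsumexp(x)
    b = mx.log(mx.softmax(x))  # naive composition
    print(mx.all(mx.abs(a - b) < 1e-6))  # True for well-scaled inputs

    # With widely separated logits the naive form can hit log(0) = -inf,
    # while the logsumexp form remains finite.
    big = mx.array([0.0, 1000.0])
    print(nn.log_softmax(big))      # [-1000, 0]
    print(mx.log(mx.softmax(big)))  # can contain -inf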
@@ -56,6 +64,14 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+def softmax(x, axis=-1):
+    r"""Applies the Softmax function.
+
+    Applies :math:`\frac{e^{x_i}}{\sum_j e^{x_j}}` element wise.
+    """
+    return mx.softmax(x, axis=axis)
+
+
 def softplus(x):
     r"""Applies the Softplus function.
 
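`softmax` here is a thin wrapper over the core `mx.softmax` op, with `axis` selecting the dimension that gets normalized. A small sketch of the axis behaviour, assuming this commit:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([[1.0, 2.0, 3.0],
                  [1.0, 1.0, 1.0]])

    rows = nn.softmax(x, axis=-1)  # each row sums to 1
    cols = nn.softmax(x, axis=0)   # each column sums to 1
    print(mx.sum(rows, axis=-1))   # approximately [1.0, 1.0]
    print(mx.sum(cols, axis=0))    # approximately [1.0, 1.0, 1.0]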
@@ -64,6 +80,14 @@ def softplus(x):
     return mx.logaddexp(x, 0)
 
 
+def softsign(x):
+    r"""Applies the Softsign function.
+
+    Applies :math:`\frac{x}{1 + |x|}` element wise.
+    """
+    return mx.divide(x, 1 + mx.abs(x))
+
+
 def celu(x, alpha=1.0):
     r"""Applies the Continuously Differentiable Exponential Linear Unit.
 
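Like tanh, softsign squashes its input into (-1, 1), but it approaches the asymptotes only at a 1/|x| rate rather than exponentially, so it saturates more gently. A quick look at that behaviour, assuming this commit:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([0.5, 1.0, 10.0, 100.0])
    print(nn.softsign(x))  # roughly [0.333, 0.5, 0.909, 0.990]; slow approach to 1
    print(mx.tanh(x))      # roughly [0.462, 0.762, 1.000, 1.000]; fast saturation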
@@ -140,6 +164,11 @@ def gelu_fast_approx(x):
 
 @_make_activation_module
 class Sigmoid(Module):
+    r"""Applies the sigmoid function, element-wise.
+
+    .. math::
+        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+    """
     pass
 
 
@@ -202,13 +231,36 @@ def mish(x: mx.array) -> mx.array:
     return x * mx.tanh(softplus(x))
 
 
+def hardswish(x):
+    r"""Applies the hardswish function, element-wise.
+
+    .. math::
+        \text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6
+    """
+    max_x_3 = mx.maximum(x + 3, 0)
+    return x * mx.minimum(max_x_3, 6) / 6
+
+
 @_make_activation_module(mish)
 class Mish(Module):
+    r"""Applies the Mish function, element-wise.
+
+    Reference: https://arxiv.org/abs/1908.08681
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+    """
     pass
 
 
 @_make_activation_module(relu)
 class ReLU(Module):
+    r"""Applies the Rectified Linear Unit.
+    Simply ``mx.maximum(x, 0)``.
+
+    See :func:`relu`, for the functional equivalent.
+    """
     pass
 
 
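The hardswish implementation is piecewise: it returns 0 for x <= -3, x for x >= 3, and the polynomial x(x + 3)/6 in between, which is exactly what the clamp `min(max(x + 3, 0), 6)` encodes. Evaluating the same points the new test uses, assuming this commit:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])
    print(nn.hardswish(x))  # [0.0, -0.375, 0.0, 1.125, 3.0]

    # The same values written as the piecewise definition.
    manual = mx.where(x <= -3, 0.0, mx.where(x >= 3, x, x * (x + 3) / 6))
    print(manual)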
@@ -249,11 +301,37 @@ class ELU(Module):
 
 @_make_activation_module(relu6)
 class ReLU6(Module):
+    r"""Applies the Rectified Linear Unit 6.
+
+    See :func:`relu6`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(softmax)
+class Softmax(Module):
+    r"""Applies the Softmax function.
+
+    See :func:`softmax`, for the functional equivalent.
+    """
     pass
 
 
 @_make_activation_module(softplus)
 class Softplus(Module):
+    r"""Applies the Softplus function.
+
+    See :func:`softplus`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(softsign)
+class Softsign(Module):
+    r"""Applies the Softsign function.
+
+    See :func:`softsign`, for the functional equivalent.
+    """
     pass
 
 
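All of the new classes follow the existing pattern in this file: `_make_activation_module` turns a functional activation into a stateless `Module` subclass whose call simply forwards to the function, so the class body only needs `pass`. The decorator itself is not part of this diff; a rough, illustrative-only sketch of the idea in plain Python (not MLX's actual implementation) might look like:

    # Illustrative-only sketch in the spirit of _make_activation_module;
    # the real MLX helper may differ.
    def make_activation_module(fn):
        def decorator(cls):
            def __call__(self, x):
                return fn(x)

            cls.__call__ = __call__
            return cls

        return decorator


    def softsign(x):
        return x / (1 + abs(x))


    @make_activation_module(softsign)
    class Softsign:
        pass


    print(Softsign()(2.0))  # 0.666...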
@@ -278,15 +356,43 @@ class CELU(Module):
 
 @_make_activation_module(silu)
 class SiLU(Module):
+    r"""Applies the Sigmoid Linear Unit. Also known as Swish.
+
+    See :func:`silu`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(log_softmax)
+class LogSoftmax(Module):
+    r"""Applies the Log Softmax function.
+
+    See :func:`log_softmax`, for the functional equivalent.
+    """
     pass
 
 
 @_make_activation_module(log_sigmoid)
 class LogSigmoid(Module):
+    r"""Applies the Log Sigmoid function.
+
+    See :func:`log_sigmoid`, for the functional equivalent.
+    """
     pass
 
 
 class PReLU(Module):
+    r"""Applies the element-wise parametric ReLU.
+    Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
+    is an array.
+
+    See :func:`prelu`, for the functional equivalent.
+
+    Args:
+        num_parameters: number of :math:`a` to learn. Default: 1
+        init: the initial value of :math:`a`. Default: 0.25
+    """
+
     def __init__(self, num_parameters=1, init=0.25):
         super().__init__()
         self.weight = mx.full([num_parameters], init)
@@ -346,6 +452,19 @@ def tanh(x):
 
 @_make_activation_module(tanh)
 class Tanh(Module):
+    r"""Applies the hyperbolic tangent function.
+
+    See :func:`tanh`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(hardswish)
+class Hardswish(Module):
+    r"""Applies the hardswish function, element-wise.
+
+    See :func:`hardswish`, for the functional equivalent.
+    """
     pass
 
 
@@ -375,4 +494,8 @@ class Step(Module):
 
 @_make_activation_module(selu)
 class SELU(Module):
+    r"""Applies the Scaled Exponential Linear Unit.
+
+    See :func:`selu`, for the functional equivalent.
+    """
     pass
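Because each of these classes just wraps its functional counterpart, the module and function forms should produce identical outputs. A quick equivalence check, assuming this commit is installed:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    for mod, fn in [(nn.Hardswish(), nn.hardswish),
                    (nn.Softsign(), nn.softsign),
                    (nn.LogSoftmax(), nn.log_softmax),
                    (nn.Softmax(), nn.softmax)]:
        print(mx.all(mod(x) == fn(x)))  # module form forwards to the function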
@@ -667,6 +667,15 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [5])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softmax(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softmax(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.6652, 0.0900, 0.2447])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_softplus(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.softplus(x)
@@ -676,6 +685,15 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softsign(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softsign(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.5, -0.5, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_celu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.celu(x)
@@ -691,6 +709,15 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_log_softmax(self):
+        x = mx.array([1.0, 2.0, 3.0])
+        y = nn.log_softmax(x)
+        epsilon = 1e-4
+        expected_y = mx.array([-2.4076, -1.4076, -0.4076])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_log_sigmoid(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.log_sigmoid(x)
@@ -712,6 +739,15 @@ class TestNN(mlx_tests.MLXTestCase):
             mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
         )
 
+    def test_hardswish(self):
+        x = mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])
+        y = nn.hardswish(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [5])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_rope(self):
         for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
             rope = nn.RoPE(4, **kwargs)
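The expected arrays in the new tests can be reproduced independently of MLX with ordinary floating-point math; a standalone check of where those constants come from:

    import math

    def softmax(xs):
        m = max(xs)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in xs]
        s = sum(exps)
        return [e / s for e in exps]

    def log_softmax(xs):
        lse = math.log(sum(math.exp(v) for v in xs))
        return [v - lse for v in xs]

    def softsign(v):
        return v / (1 + abs(v))

    def hardswish(v):
        return v * min(max(v + 3, 0), 6) / 6

    print([round(p, 4) for p in softmax([1.0, -1.0, 0.0])])     # [0.6652, 0.09, 0.2447]
    print([round(p, 4) for p in log_softmax([1.0, 2.0, 3.0])])  # [-2.4076, -1.4076, -0.4076]
    print([softsign(v) for v in [1.0, -1.0, 0.0]])              # [0.5, -0.5, 0.0]
    # Prints [-0.0, -0.375, 0.0, 1.125, 3.0]; the leading -0.0 equals 0.0
    # and is within the test's tolerance.
    print([hardswish(v) for v in [-3.0, -1.5, 0.0, 1.5, 3.0]])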