feat: add softsign, softmax, hardswish, logsoftmax activation functions (#309)

* feat: add softsign activation function

* run pre-commit

* Add Softsign activation function

* Add Softsign activation function

* Add documentation for ReLU6, Softplus, and Softsign activations

* Update activation functions in neural network layers

* Add LogSoftmax and Hardswish activations

* run pre-commit

* Update activations.py

* Added acknowledgements

* Fix activation function comments

* Fix activation functions in neural network layers
Nripesh Niketan 2023-12-29 23:49:36 +04:00 committed by GitHub
parent 2aedf3e791
commit 5ad8fb7268
4 changed files with 172 additions and 4 deletions


@@ -6,7 +6,8 @@ with a short description of your contribution(s) below. For example:
- Jane Smith: Added the `foo` and `bar` ops.
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added tri, tril, triu and safetensor support


@@ -5,14 +5,18 @@ from mlx.nn.layers.activations import (
ELU,
GELU,
SELU,
Hardswish,
LeakyReLU,
LogSigmoid,
LogSoftmax,
Mish,
PReLU,
ReLU,
ReLU6,
SiLU,
Softmax,
Softplus,
Softsign,
Step,
Tanh,
celu,
@@ -20,15 +24,19 @@ from mlx.nn.layers.activations import (
gelu,
gelu_approx,
gelu_fast_approx,
hardswish,
leaky_relu,
log_sigmoid,
log_softmax,
mish,
prelu,
relu,
relu6,
selu,
silu,
softmax,
softplus,
softsign,
step,
tanh,
)
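
The new names above are exported both as lower-case functions and as module classes. A minimal sketch of the two forms, assuming only the `mlx.core` and `mlx.nn` imports already used in the tests further down:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, -1.0, 0.0])

# Functional form, re-exported through mlx.nn
y_fn = nn.softsign(x)

# Module form wraps the same computation in a callable layer
y_mod = nn.Softsign()(x)

# Both paths agree, since the module just calls the function
print(mx.all(mx.abs(y_fn - y_mod) < 1e-6))
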


@@ -25,7 +25,7 @@ def sigmoid(x):
def relu(x):
"""Applies the Rectified Linear Unit.
r"""Applies the Rectified Linear Unit.
Simply ``mx.maximum(x, 0)``.
"""
@@ -33,15 +33,23 @@ def relu(x):
def leaky_relu(x, negative_slope=0.01):
"""Applies the Leaky Rectified Linear Unit.
r"""Applies the Leaky Rectified Linear Unit.
Simply ``mx.maximum(negative_slope * x, x)``.
"""
return mx.maximum(negative_slope * x, x)
def log_softmax(x, axis=-1):
r"""Applies the Log Softmax function.
Applies :math:`x_i - \log \sum_j e^{x_j}` element wise.
"""
return x - mx.logsumexp(x, axis=axis, keepdims=True)
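
Subtracting `mx.logsumexp` rather than computing a literal `log(sum(exp(x)))` keeps the result finite for large inputs. A small sketch of the equivalence; `mx.exp`, `mx.log`, and `mx.sum` are standard `mlx.core` ops assumed available here:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, 2.0, 3.0])

# Literal formula: fine for small inputs, but exp() overflows for large ones
naive = x - mx.log(mx.sum(mx.exp(x)))

# Implementation above, using the numerically stable logsumexp
stable = nn.log_softmax(x)
print(mx.all(mx.abs(naive - stable) < 1e-5))

# logsumexp keeps working where the literal formula would produce inf/nan
big = mx.array([1000.0, 1001.0, 1002.0])
print(nn.log_softmax(big))  # roughly [-2.4076, -1.4076, -0.4076]
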
def elu(x, alpha=1.0):
"""Applies the Exponential Linear Unit.
r"""Applies the Exponential Linear Unit.
Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
"""
@@ -56,6 +64,14 @@ def relu6(x):
return mx.minimum(mx.maximum(x, 0), 6.0)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
Applies :math:`\frac{e^{x_i}}{\sum_j e^{x_j}}` element wise.
"""
return mx.softmax(x, axis=axis)
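
The `axis` argument controls which dimension is normalized, which matters for batched inputs. A quick sketch, assuming `mx.sum` from `mlx.core`:

import mlx.core as mx
import mlx.nn as nn

# One row of logits per sample; normalize along the last axis
logits = mx.array([[1.0, -1.0, 0.0],
                   [2.0, 2.0, 2.0]])
probs = nn.softmax(logits, axis=-1)

# Every entry is positive and each row sums to 1
print(mx.all(probs > 0))
print(mx.sum(probs, axis=-1))  # roughly [1.0, 1.0]
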
def softplus(x):
r"""Applies the Softplus function.
@@ -64,6 +80,14 @@ def softplus(x):
return mx.logaddexp(x, 0)
def softsign(x):
r"""Applies the Softsign function.
Applies :math:`\frac{x}{1 + |x|}` element wise.
"""
return mx.divide(x, 1 + mx.abs(x))
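
Softsign saturates toward ±1 like tanh, but only polynomially fast, so it flattens out much more gently in the tails. A small numerical sketch; `mx.tanh` is assumed available in `mlx.core`:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-10.0, -1.0, 0.0, 1.0, 10.0])
print(nn.softsign(x))  # roughly [-0.9091, -0.5, 0.0, 0.5, 0.9091]
print(mx.tanh(x))      # saturates much faster: roughly [-1.0, -0.7616, 0.0, 0.7616, 1.0]
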
def celu(x, alpha=1.0):
r"""Applies the Continuously Differentiable Exponential Linear Unit.
@@ -140,6 +164,11 @@ def gelu_fast_approx(x):
@_make_activation_module
class Sigmoid(Module):
r"""Applies the sigmoid function, element-wise.
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
"""
pass
@@ -202,13 +231,36 @@ def mish(x: mx.array) -> mx.array:
return x * mx.tanh(softplus(x))
def hardswish(x):
r"""Applies the hardswish function, element-wise.
.. math::
\text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6
"""
max_x_3 = mx.maximum(x + 3, 0)
return x * mx.minimum(max_x_3, 6) / 6
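
Equivalently, hardswish is piecewise: zero for x <= -3, the identity for x >= 3, and x(x + 3)/6 in between. A sketch checking the implementation against that piecewise reference; `mx.where` and `mx.zeros_like` are assumed available from `mlx.core`:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-4.0, -3.0, -1.5, 0.0, 1.5, 3.0, 4.0])
y = nn.hardswish(x)

# Piecewise reference: zero below -3, identity above 3, x * (x + 3) / 6 in between
reference = mx.where(x <= -3, mx.zeros_like(x), mx.where(x >= 3, x, x * (x + 3) / 6))
print(mx.all(mx.abs(y - reference) < 1e-6))
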
@_make_activation_module(mish)
class Mish(Module):
r"""Applies the Mish function, element-wise.
Reference: https://arxiv.org/abs/1908.08681
.. math::
\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
"""
pass
@_make_activation_module(relu)
class ReLU(Module):
r"""Applies the Rectified Linear Unit.
Simply ``mx.maximum(x, 0)``.
See :func:`relu`, for the functional equivalent.
"""
pass
@@ -249,11 +301,37 @@ class ELU(Module):
@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6`, for the functional equivalent.
"""
pass
@_make_activation_module(softmax)
class Softmax(Module):
r"""Applies the Softmax function.
See :func:`softmax`, for the functional equivalent.
"""
pass
@_make_activation_module(softplus)
class Softplus(Module):
r"""Applies the Softplus function.
See :func:`softplus`, for the functional equivalent.
"""
pass
@_make_activation_module(softsign)
class Softsign(Module):
r"""Applies the Softsign function.
See :func:`softsign`, for the functional equivalent.
"""
pass
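
The class wrappers add nothing beyond the functional forms, but they let the activations be dropped into layer containers. A hypothetical model sketch; `nn.Linear` and `nn.Sequential` are assumed available and are not part of this diff:

import mlx.core as mx
import mlx.nn as nn

# Hypothetical classifier head (nn.Linear and nn.Sequential are assumptions, not part of this change)
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Hardswish(),
    nn.Linear(8, 3),
    nn.Softmax(),
)

x = mx.ones([2, 4])
print(model(x).shape)  # [2, 3]
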
@@ -278,15 +356,43 @@ class CELU(Module):
@_make_activation_module(silu)
class SiLU(Module):
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
See :func:`silu`, for the functional equivalent.
"""
pass
@_make_activation_module(log_softmax)
class LogSoftmax(Module):
r"""Applies the Log Softmax function.
See :func:`log_softmax`, for the functional equivalent.
"""
pass
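
A LogSoftmax output layer pairs naturally with a negative log-likelihood objective, since the loss is then just the target log-probabilities negated. A hand-rolled sketch with one-hot targets; `mx.mean` and `mx.sum` are assumed available, and no mlx loss API is used:

import mlx.core as mx
import mlx.nn as nn

logits = mx.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
log_probs = nn.LogSoftmax()(logits)

# One-hot targets for classes 0 and 2; NLL is the mean of the negated target log-probabilities
targets = mx.array([[1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0]])
nll = -mx.mean(mx.sum(log_probs * targets, axis=-1))
print(nll)
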
@_make_activation_module(log_sigmoid)
class LogSigmoid(Module):
r"""Applies the Log Sigmoid function.
See :func:`log_sigmoid`, for the functional equivalent.
"""
pass
class PReLU(Module):
r"""Applies the element-wise parametric ReLU.
Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
is an array.
See :func:`prelu`, for the functional equivalent.
Args:
num_parameters: number of :math:`a` to learn. Default: 1
init: the initial value of :math:`a`. Default: 0.25
"""
def __init__(self, num_parameters=1, init=0.25):
super().__init__()
self.weight = mx.full([num_parameters], init)
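
Unlike `leaky_relu`, which uses a fixed `negative_slope`, `PReLU` stores its slope as a learnable parameter in `self.weight`. A quick sketch of the difference, assuming the module's call applies `prelu` with that stored weight:

import mlx.core as mx
import mlx.nn as nn

layer = nn.PReLU()       # a single learnable slope, initialized to 0.25
x = mx.array([2.0, -2.0])

print(layer.weight)      # [0.25]
print(layer(x))          # [2.0, -0.5]: negative inputs are scaled by the learned slope
print(nn.leaky_relu(x))  # fixed slope of 0.01: [2.0, -0.02]
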
@@ -346,6 +452,19 @@ def tanh(x):
@_make_activation_module(tanh)
class Tanh(Module):
r"""Applies the hyperbolic tangent function.
See :func:`tanh`, for the functional equivalent.
"""
pass
@_make_activation_module(hardswish)
class Hardswish(Module):
r"""Applies the hardswish function, element-wise.
See :func:`hardswish`, for the functional equivalent.
"""
pass
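
Hardswish is commonly used as a cheaper drop-in for SiLU/Swish, replacing the sigmoid with clipped linear pieces; the two track each other closely around zero. A small comparison against the existing `nn.silu`:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-4.0, -1.0, 0.0, 1.0, 4.0])
print(nn.Hardswish()(x))  # roughly [0.0, -0.3333, 0.0, 0.6667, 4.0]
print(nn.silu(x))         # roughly [-0.0719, -0.2689, 0.0, 0.7311, 3.9281]
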
@@ -375,4 +494,8 @@ class Step(Module):
@_make_activation_module(selu)
class SELU(Module):
r"""Applies the Scaled Exponential Linear Unit.
See :func:`selu`, for the functional equivalent.
"""
pass


@@ -667,6 +667,15 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, [5])
self.assertEqual(y.dtype, mx.float32)
def test_softmax(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.softmax(x)
epsilon = 1e-4
expected_y = mx.array([0.6652, 0.0900, 0.2447])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_softplus(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.softplus(x)
@@ -676,6 +685,15 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_softsign(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.softsign(x)
epsilon = 1e-4
expected_y = mx.array([0.5, -0.5, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_celu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.celu(x)
@@ -691,6 +709,15 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_log_softmax(self):
x = mx.array([1.0, 2.0, 3.0])
y = nn.log_softmax(x)
epsilon = 1e-4
expected_y = mx.array([-2.4076, -1.4076, -0.4076])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_log_sigmoid(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.log_sigmoid(x)
@@ -712,6 +739,15 @@ class TestNN(mlx_tests.MLXTestCase):
mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
)
def test_hardswish(self):
x = mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])
y = nn.hardswish(x)
epsilon = 1e-4
expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [5])
self.assertEqual(y.dtype, mx.float32)
def test_rope(self):
for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
rope = nn.RoPE(4, **kwargs)