diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 6e429db2a..ab3d5e1af 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -6,7 +6,8 @@ with a short description of your contribution(s) below. For example:
 - Jane Smith: Added the `foo` and `bar` ops.
 
 MLX was developed with contributions from the following individuals:
-
+
+- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `log_softmax` activation functions.
 - Juarez Bochi: Fixed bug in cross attention.
 - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 - Diogo Da Cruz: Added tri, tril, triu and safetensor support
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 4dbe96eb6..84d2ceb9f 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -5,14 +5,18 @@ from mlx.nn.layers.activations import (
     ELU,
     GELU,
     SELU,
+    Hardswish,
     LeakyReLU,
     LogSigmoid,
+    LogSoftmax,
     Mish,
     PReLU,
     ReLU,
     ReLU6,
     SiLU,
+    Softmax,
     Softplus,
+    Softsign,
     Step,
     Tanh,
     celu,
@@ -20,15 +24,19 @@ from mlx.nn.layers.activations import (
     gelu,
     gelu_approx,
     gelu_fast_approx,
+    hardswish,
     leaky_relu,
     log_sigmoid,
+    log_softmax,
     mish,
     prelu,
     relu,
     relu6,
     selu,
     silu,
+    softmax,
     softplus,
+    softsign,
     step,
     tanh,
 )
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 20294380c..fcb4483ac 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -25,7 +25,7 @@ def sigmoid(x):
 
 
 def relu(x):
-    """Applies the Rectified Linear Unit.
+    r"""Applies the Rectified Linear Unit.
 
     Simply ``mx.maximum(x, 0)``.
     """
@@ -33,15 +33,23 @@ def relu(x):
 
 
 def leaky_relu(x, negative_slope=0.01):
-    """Applies the Leaky Rectified Linear Unit.
+    r"""Applies the Leaky Rectified Linear Unit.
 
     Simply ``mx.maximum(negative_slope * x, x)``.
     """
     return mx.maximum(negative_slope * x, x)
 
 
+def log_softmax(x, axis=-1):
+    r"""Applies the Log Softmax function.
+
+    Applies :math:`x - \log \sum_i e^{x_i}` element wise.
+    """
+    return x - mx.logsumexp(x, axis=axis, keepdims=True)
+
+
 def elu(x, alpha=1.0):
-    """Applies the Exponential Linear Unit.
+    r"""Applies the Exponential Linear Unit.
 
     Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
     """
@@ -56,6 +64,14 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+def softmax(x, axis=-1):
+    r"""Applies the Softmax function.
+
+    Applies :math:`\frac{e^{x_i}}{\sum_j e^{x_j}}` element wise.
+    """
+    return mx.softmax(x, axis=axis)
+
+
 def softplus(x):
     r"""Applies the Softplus function.
 
@@ -64,6 +80,14 @@ def softplus(x):
     return mx.logaddexp(x, 0)
 
 
+def softsign(x):
+    r"""Applies the Softsign function.
+
+    Applies :math:`\frac{x}{1 + |x|}` element wise.
+    """
+    return mx.divide(x, 1 + mx.abs(x))
+
+
 def celu(x, alpha=1.0):
     r"""Applies the Continuously Differentiable Exponential Linear Unit.
 
@@ -140,6 +164,11 @@ def gelu_fast_approx(x):
 
 @_make_activation_module
 class Sigmoid(Module):
+    r"""Applies the sigmoid function, element-wise.
+
+    .. math::
+        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+    """
     pass
 
 
@@ -202,13 +231,36 @@ def mish(x: mx.array) -> mx.array:
     return x * mx.tanh(softplus(x))
 
 
+def hardswish(x):
+    r"""Applies the hardswish function, element-wise.
+
+    .. math::
+        \text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6
+    """
+    max_x_3 = mx.maximum(x + 3, 0)
+    return x * mx.minimum(max_x_3, 6) / 6
+
+
 @_make_activation_module(mish)
 class Mish(Module):
+    r"""Applies the Mish function, element-wise.
+
+    Reference: https://arxiv.org/abs/1908.08681
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+    """
     pass
 
 
 @_make_activation_module(relu)
 class ReLU(Module):
+    r"""Applies the Rectified Linear Unit.
+    Simply ``mx.maximum(x, 0)``.
+
+    See :func:`relu`, for the functional equivalent.
+    """
     pass
 
 
@@ -249,11 +301,37 @@ class ELU(Module):
 
 @_make_activation_module(relu6)
 class ReLU6(Module):
+    r"""Applies the Rectified Linear Unit 6.
+
+    See :func:`relu6`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(softmax)
+class Softmax(Module):
+    r"""Applies the Softmax function.
+
+    See :func:`softmax`, for the functional equivalent.
+    """
     pass
 
 
 @_make_activation_module(softplus)
 class Softplus(Module):
+    r"""Applies the Softplus function.
+
+    See :func:`softplus`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(softsign)
+class Softsign(Module):
+    r"""Applies the Softsign function.
+
+    See :func:`softsign`, for the functional equivalent.
+    """
     pass
 
 
@@ -278,15 +356,43 @@ class CELU(Module):
 
 @_make_activation_module(silu)
 class SiLU(Module):
+    r"""Applies the Sigmoid Linear Unit. Also known as Swish.
+
+    See :func:`silu`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(log_softmax)
+class LogSoftmax(Module):
+    r"""Applies the Log Softmax function.
+
+    See :func:`log_softmax`, for the functional equivalent.
+    """
     pass
 
 
 @_make_activation_module(log_sigmoid)
 class LogSigmoid(Module):
+    r"""Applies the Log Sigmoid function.
+
+    See :func:`log_sigmoid`, for the functional equivalent.
+    """
     pass
 
 
 class PReLU(Module):
+    r"""Applies the element-wise parametric ReLU.
+    Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
+    is an array.
+
+    See :func:`prelu`, for the functional equivalent.
+
+    Args:
+        num_parameters: number of :math:`a` to learn. Default: 1
+        init: the initial value of :math:`a`. Default: 0.25
+    """
+
     def __init__(self, num_parameters=1, init=0.25):
         super().__init__()
         self.weight = mx.full([num_parameters], init)
@@ -346,6 +452,19 @@ def tanh(x):
 
 @_make_activation_module(tanh)
 class Tanh(Module):
+    r"""Applies the hyperbolic tangent function.
+
+    See :func:`tanh`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(hardswish)
+class Hardswish(Module):
+    r"""Applies the hardswish function, element-wise.
+
+    See :func:`hardswish`, for the functional equivalent.
+    """
     pass
 
 
@@ -375,4 +494,8 @@ class Step(Module):
 
 @_make_activation_module(selu)
 class SELU(Module):
+    r"""Applies the Scaled Exponential Linear Unit.
+
+    See :func:`selu`, for the functional equivalent.
+    """
     pass
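
As a quick sanity check on the functional forms added above (not part of the patch, and assuming the patched MLX is importable as usual), the values below follow directly from the docstring formulas and are the same expectations asserted by the new unit tests that follow.

```python
# Sanity check of the new functional activations (sketch, not part of the patch).
# Assumes this branch of MLX is installed so `mlx.nn` exposes the new names.
import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, -1.0, 0.0])

print(nn.softmax(x))    # ~[0.6652, 0.0900, 0.2447]      e^{x_i} / sum_j e^{x_j}
print(nn.softsign(x))   # [0.5, -0.5, 0.0]                x / (1 + |x|)
print(nn.log_softmax(mx.array([1.0, 2.0, 3.0])))
                        # ~[-2.4076, -1.4076, -0.4076]    x - logsumexp(x)
print(nn.hardswish(mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])))
                        # [0.0, -0.375, 0.0, 1.125, 3.0]  x * min(max(x + 3, 0), 6) / 6
```
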
+ """ pass diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 8210b4a48..204145c01 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -667,6 +667,15 @@ class TestNN(mlx_tests.MLXTestCase): self.assertEqual(y.shape, [5]) self.assertEqual(y.dtype, mx.float32) + def test_softmax(self): + x = mx.array([1.0, -1.0, 0.0]) + y = nn.softmax(x) + epsilon = 1e-4 + expected_y = mx.array([0.6652, 0.0900, 0.2447]) + self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) + self.assertEqual(y.shape, [3]) + self.assertEqual(y.dtype, mx.float32) + def test_softplus(self): x = mx.array([1.0, -1.0, 0.0]) y = nn.softplus(x) @@ -676,6 +685,15 @@ class TestNN(mlx_tests.MLXTestCase): self.assertEqual(y.shape, [3]) self.assertEqual(y.dtype, mx.float32) + def test_softsign(self): + x = mx.array([1.0, -1.0, 0.0]) + y = nn.softsign(x) + epsilon = 1e-4 + expected_y = mx.array([0.5, -0.5, 0.0]) + self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) + self.assertEqual(y.shape, [3]) + self.assertEqual(y.dtype, mx.float32) + def test_celu(self): x = mx.array([1.0, -1.0, 0.0]) y = nn.celu(x) @@ -691,6 +709,15 @@ class TestNN(mlx_tests.MLXTestCase): self.assertEqual(y.shape, [3]) self.assertEqual(y.dtype, mx.float32) + def test_log_softmax(self): + x = mx.array([1.0, 2.0, 3.0]) + y = nn.log_softmax(x) + epsilon = 1e-4 + expected_y = mx.array([-2.4076, -1.4076, -0.4076]) + self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) + self.assertEqual(y.shape, [3]) + self.assertEqual(y.dtype, mx.float32) + def test_log_sigmoid(self): x = mx.array([1.0, -1.0, 0.0]) y = nn.log_sigmoid(x) @@ -712,6 +739,15 @@ class TestNN(mlx_tests.MLXTestCase): mx.array([0.8651, -0.3034, 0.0000, 0.3752]), ) + def test_hardswish(self): + x = mx.array([-3.0, -1.5, 0.0, 1.5, 3.0]) + y = nn.hardswish(x) + epsilon = 1e-4 + expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0]) + self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) + self.assertEqual(y.shape, [5]) + self.assertEqual(y.dtype, mx.float32) + def test_rope(self): for kwargs in [{}, {"traditional": False}, {"base": 10000}]: rope = nn.RoPE(4, **kwargs)