diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 4b0cea123..fe7b86939 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
+- Gökdeniz Gülmez: Added the `clipped_silu` and `ClippedSiLU` activations.
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 26f77917f..929d3b032 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -6,6 +6,7 @@ from mlx.nn.layers.activations import (
     GELU,
     GLU,
     SELU,
+    ClippedSiLU,
     HardShrink,
     Hardswish,
     HardTanh,
@@ -26,6 +27,7 @@ from mlx.nn.layers.activations import (
     Step,
     Tanh,
     celu,
+    clipped_silu,
     elu,
     gelu,
     gelu_approx,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 8eafd75d3..2076b192a 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -132,6 +132,20 @@ def silu(x):
     return x * mx.sigmoid(x)
 
 
+@partial(mx.compile, shapeless=True)
+def clipped_silu(x, a_min=-100, a_max=100):
+    r"""Applies the Clipped Sigmoid Linear Unit.
+
+    Applies :math:`\text{clip}(x \sigma(x), a\_min, a\_max)` element wise, where
+    :math:`\sigma(\cdot)` is the logistic sigmoid.
+
+    Args:
+        a_min: minimum value for clipping. Default: ``-100``
+        a_max: maximum value for clipping. Default: ``100``
+    """
+    return mx.clip(x * mx.sigmoid(x), a_min=a_min, a_max=a_max)
+
+
 @partial(mx.compile, shapeless=True)
 def log_sigmoid(x):
     r"""Applies the Log Sigmoid function.
@@ -488,6 +502,25 @@ class SiLU(Module):
     """
 
 
+class ClippedSiLU(Module):
+    r"""Applies the Clipped Sigmoid Linear Unit.
+
+    See :func:`clipped_silu` for the functional equivalent.
+
+    Args:
+        a_min: minimum value for clipping. Default: ``-100``
+        a_max: maximum value for clipping. Default: ``100``
+    """
+
+    def __init__(self, a_min=-100, a_max=100):
+        super().__init__()
+        self.a_min = a_min
+        self.a_max = a_max
+
+    def __call__(self, x):
+        return clipped_silu(x, self.a_min, self.a_max)
+
+
 @_make_activation_module(log_softmax)
 class LogSoftmax(Module):
     r"""Applies the Log Softmax function.
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 10bbe821e..4bec32d2c 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -956,6 +956,43 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_silu(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.silu(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.7311, -0.2689, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.SiLU()(x)
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_clipped_silu(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.clipped_silu(x, a_min=-100, a_max=100)
+        epsilon = 1e-4
+        expected_y = mx.array([0.7311, -0.2689, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.ClippedSiLU(a_min=-100, a_max=100)(x)
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
+        x_extreme = mx.array([200.0, -200.0])
+        y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
+        expected_clipped = mx.array([50.0, 0.0])
+        self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))
+
+        y_custom = nn.ClippedSiLU(a_min=-0.1, a_max=10)(mx.array([200.0, -1.0]))
+        expected_custom = mx.array([10.0, -0.1])
+        self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))
+
     def test_log_softmax(self):
         x = mx.array([1.0, 2.0, 3.0])
         y = nn.log_softmax(x)
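
For reference, a minimal usage sketch of the new activation, assuming this patch is applied on top of a working MLX install (values in comments are approximate):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-1.0, 0.0, 1.0, 200.0])

    # Functional form: SiLU followed by clipping to [a_min, a_max].
    y = nn.clipped_silu(x, a_min=-50, a_max=50)  # ~[-0.2689, 0.0, 0.7311, 50.0]

    # Module form, usable like any other activation layer.
    act = nn.ClippedSiLU(a_min=-50, a_max=50)
    z = act(x)  # same values as above

Note that because SiLU is bounded below by roughly -0.2785, the lower clip only takes effect when `a_min` is set above that value; large negative inputs produce outputs near zero rather than large negative values.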