commit faa5f54f5b
Author: Gökdeniz Gülmez
Date:   2025-06-17 19:03:19 +12:00
4 changed files with 73 additions and 0 deletions

ACKNOWLEDGMENTS.md

@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation.
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `clipped_silu` activation.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

python/mlx/nn/layers/__init__.py

@@ -6,6 +6,7 @@ from mlx.nn.layers.activations import (
    GELU,
    GLU,
    SELU,
    ClippedSiLU,
    HardShrink,
    Hardswish,
    HardTanh,
@@ -26,6 +27,7 @@ from mlx.nn.layers.activations import (
    Step,
    Tanh,
    celu,
    clipped_silu,
    elu,
    gelu,
    gelu_approx,

python/mlx/nn/layers/activations.py

@@ -132,6 +132,20 @@ def silu(x):
    return x * mx.sigmoid(x)


@partial(mx.compile, shapeless=True)
def clipped_silu(x, a_min=-100, a_max=100):
    r"""Applies the Clipped Sigmoid Linear Unit.

    Applies :math:`\text{clip}(x \sigma(x), a\_min, a\_max)` element wise,
    where :math:`\sigma(\cdot)` is the logistic sigmoid.

    Args:
        a_min: minimum value for clipping. Default: ``-100``
        a_max: maximum value for clipping. Default: ``100``
    """
    return mx.clip(x * mx.sigmoid(x), a_min=a_min, a_max=a_max)


@partial(mx.compile, shapeless=True)
def log_sigmoid(x):
    r"""Applies the Log Sigmoid function.
@@ -488,6 +502,25 @@ class SiLU(Module):
    """


class ClippedSiLU(Module):
    r"""Applies the Clipped Sigmoid Linear Unit.

    See :func:`clipped_silu` for the functional equivalent.

    Args:
        a_min: minimum value for clipping. Default: ``-100``
        a_max: maximum value for clipping. Default: ``100``
    """

    def __init__(self, a_min=-100, a_max=100):
        super().__init__()
        self.a_min = a_min
        self.a_max = a_max

    def __call__(self, x):
        return clipped_silu(x, self.a_min, self.a_max)


@_make_activation_module(log_softmax)
class LogSoftmax(Module):
    r"""Applies the Log Softmax function.

python/tests/test_nn.py

@@ -956,6 +956,43 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_silu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.silu(x)
        epsilon = 1e-4
        expected_y = mx.array([0.7311, -0.2689, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

        y = nn.SiLU()(x)
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_clipped_silu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.clipped_silu(x, a_min=-100, a_max=100)
        epsilon = 1e-4
        expected_y = mx.array([0.7311, -0.2689, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

        y = nn.ClippedSiLU(a_min=-100, a_max=100)(x)
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)
        # silu(x) = x * sigmoid(x) approaches x for large positive x, so
        # the upper bound engages; for large negative x it approaches 0
        # (silu's global minimum is about -0.278), so the lower bound is
        # never reached and the clipped value is ~0, not a_min.
        x_extreme = mx.array([200.0, -200.0])
        y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
        expected_clipped = mx.array([50.0, 0.0])
        self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))

        y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
        expected_custom = mx.array([10.0, 0.0])
        self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))
    def test_log_softmax(self):
        x = mx.array([1.0, 2.0, 3.0])
        y = nn.log_softmax(x)