Commit faa5f54f5b: Merge 74d6ebd4bd into b8022c578a
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `clipped_silu` activation.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
@@ -6,6 +6,7 @@ from mlx.nn.layers.activations import (
    GELU,
    GLU,
    SELU,
    ClippedSiLU,
    HardShrink,
    Hardswish,
    HardTanh,
@@ -26,6 +27,7 @@ from mlx.nn.layers.activations import (
    Step,
    Tanh,
    celu,
    clipped_silu,
    elu,
    gelu,
    gelu_approx,
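Once these re-exports are in place, both spellings should resolve through the `mlx.nn` namespace (the tests below use them this way). A minimal usage sketch, assuming a build of MLX that includes this patch:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, -1.0, 0.0])
    y_fn = nn.clipped_silu(x)    # functional form, re-exported here
    y_mod = nn.ClippedSiLU()(x)  # module form, same defaults (a_min=-100, a_max=100)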
@@ -132,6 +132,20 @@ def silu(x):
    return x * mx.sigmoid(x)


@partial(mx.compile, shapeless=True)
def clipped_silu(x, a_min=-100, a_max=100):
    r"""Applies the Clipped Sigmoid Linear Unit.

    Applies :math:`\text{clip}(x \sigma(x), a\_min, a\_max)` element wise, where
    :math:`\sigma(\cdot)` is the logistic sigmoid.

    Args:
        a_min: minimum value for clipping. Default: ``-100``
        a_max: maximum value for clipping. Default: ``100``
    """
    return mx.clip(x * mx.sigmoid(x), a_min=a_min, a_max=a_max)


@partial(mx.compile, shapeless=True)
def log_sigmoid(x):
    r"""Applies the Log Sigmoid function.
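To make the formula concrete, here is a small worked sketch of clip(x * sigmoid(x), a_min, a_max); the numbers are hand-computed approximations (sigmoid(1) ≈ 0.7311), not outputs copied from MLX:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, 200.0, -200.0])
    # x * sigmoid(x) is roughly [0.7311, 200.0, 0.0]; clipping to [-50, 50]
    # only affects the 200.0 entry, which becomes 50.0.
    y = nn.clipped_silu(x, a_min=-50, a_max=50)
    # y is approximately [0.7311, 50.0, 0.0]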
@@ -488,6 +502,25 @@ class SiLU(Module):
    """


class ClippedSiLU(Module):
    r"""Applies the Clipped Sigmoid Linear Unit.

    See :func:`clipped_silu` for the functional equivalent.

    Args:
        a_min: minimum value for clipping. Default: ``-100``
        a_max: maximum value for clipping. Default: ``100``
    """

    def __init__(self, a_min=-100, a_max=100):
        super().__init__()
        self.a_min = a_min
        self.a_max = a_max

    def __call__(self, x):
        return clipped_silu(x, self.a_min, self.a_max)


@_make_activation_module(log_softmax)
class LogSoftmax(Module):
    r"""Applies the Log Softmax function.
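Because `ClippedSiLU` is a standard `Module`, it can be dropped into a layer stack like any other MLX activation. A minimal sketch; the layer sizes and clip bounds here are arbitrary values chosen for illustration, not taken from this patch:

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(
        nn.Linear(8, 16),
        nn.ClippedSiLU(a_min=-20, a_max=20),  # bounded nonlinearity between the linear layers
        nn.Linear(16, 4),
    )
    out = model(mx.zeros((2, 8)))  # shape: (2, 4)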
@@ -956,6 +956,43 @@ class TestLayers(mlx_tests.MLXTestCase):
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_silu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.silu(x)
        epsilon = 1e-4
        expected_y = mx.array([0.7311, -0.2689, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

        y = nn.SiLU()(x)
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

    def test_clipped_silu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.clipped_silu(x, a_min=-100, a_max=100)
        epsilon = 1e-4
        expected_y = mx.array([0.7311, -0.2689, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

        y = nn.ClippedSiLU(a_min=-100, a_max=100)(x)
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, (3,))
        self.assertEqual(y.dtype, mx.float32)

        x_extreme = mx.array([200.0, -200.0])
        y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
        expected_clipped = mx.array([50.0, 0.0])  # silu(-200) is ~0, so only the upper bound clips
        self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))

        y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
        expected_custom = mx.array([10.0, 0.0])  # again, only a_max engages; x*sigmoid(x) never drops below ~-0.28
        self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))

    def test_log_softmax(self):
        x = mx.array([1.0, 2.0, 3.0])
        y = nn.log_softmax(x)