mirror of https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00

Merge 74d6ebd4bd into b8022c578a
commit faa5f54f5b
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
+- Gökkdeniz Gülmez: Added `clipped SiLU`.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
@@ -6,6 +6,7 @@ from mlx.nn.layers.activations import (
     GELU,
     GLU,
     SELU,
+    ClippedSiLU,
     HardShrink,
     Hardswish,
     HardTanh,
@@ -26,6 +27,7 @@ from mlx.nn.layers.activations import (
     Step,
     Tanh,
     celu,
+    clipped_silu,
     elu,
     gelu,
     gelu_approx,
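After this change, both spellings should be reachable from the top-level `mlx.nn` namespace, which is how the tests further down call them. A minimal import sketch, assuming a build that includes this commit is installed:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([0.5, -0.5])
print(nn.clipped_silu(x))   # functional form, default bounds of +/-100
print(nn.ClippedSiLU()(x))  # module form wrapping the same op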
@@ -132,6 +132,20 @@ def silu(x):
     return x * mx.sigmoid(x)
 
 
+@partial(mx.compile, shapeless=True)
+def clipped_silu(x, a_min=-100, a_max=100):
+    r"""Applies the Clipped Sigmoid Linear Unit.
+
+    Applies :math:`\text{clip}(x \sigma(x), a\_min, a\_max)` element wise, where
+    :math:`\sigma(\cdot)` is the logistic sigmoid.
+
+    Args:
+        a_min: minimum value for clipping. Default: ``-100``
+        a_max: maximum value for clipping. Default: ``100``
+    """
+    return mx.clip(x * mx.sigmoid(x), a_min=a_min, a_max=a_max)
+
+
 @partial(mx.compile, shapeless=True)
 def log_sigmoid(x):
     r"""Applies the Log Sigmoid function.
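The clip bounds matter because for large positive inputs `x * sigmoid(x)` approaches `x` itself. A minimal reference sketch of the formula in the docstring above, written against plain `mlx.core` ops so it runs without this branch; `clipped_silu_ref` is a name chosen here for illustration:

import mlx.core as mx

def clipped_silu_ref(x, a_min=-100, a_max=100):
    # Mirrors the body of the function in the hunk above:
    # clip(x * sigmoid(x), a_min, a_max), applied elementwise.
    return mx.clip(x * mx.sigmoid(x), a_min=a_min, a_max=a_max)

x = mx.array([1.0, -1.0, 200.0])
print(clipped_silu_ref(x))
# ~[0.7311, -0.2689, 100.0]; 200 * sigmoid(200) ~ 200, so it hits the upper bound.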
@@ -488,6 +502,25 @@ class SiLU(Module):
     """
 
 
+class ClippedSiLU(Module):
+    r"""Applies the Clipped Sigmoid Linear Unit.
+
+    See :func:`clipped_silu` for the functional equivalent.
+
+    Args:
+        a_min: minimum value for clipping. Default: ``-100``
+        a_max: maximum value for clipping. Default: ``100``
+    """
+
+    def __init__(self, a_min=-100, a_max=100):
+        super().__init__()
+        self.a_min = a_min
+        self.a_max = a_max
+
+    def __call__(self, x):
+        return clipped_silu(x, self.a_min, self.a_max)
+
+
 @_make_activation_module(log_softmax)
 class LogSoftmax(Module):
     r"""Applies the Log Softmax function.
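Because `ClippedSiLU` holds only the two clip bounds and no parameters, it can be dropped anywhere an activation module is expected. A usage sketch, assuming this branch is installed; the layer sizes and the a_min/a_max values are illustrative, not values suggested by the commit:

import mlx.core as mx
import mlx.nn as nn

# A tiny MLP using the module form added above.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ClippedSiLU(a_min=-6, a_max=6),
    nn.Linear(16, 1),
)
y = model(mx.random.normal((4, 8)))
print(y.shape)  # (4, 1)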
@@ -956,6 +956,43 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_silu(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.silu(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.7311, -0.2689, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.SiLU()(x)
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_clipped_silu(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.clipped_silu(x, a_min=-100, a_max=100)
+        epsilon = 1e-4
+        expected_y = mx.array([0.7311, -0.2689, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.ClippedSiLU(a_min=-100, a_max=100)(x)
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, (3,))
+        self.assertEqual(y.dtype, mx.float32)
+
+        x_extreme = mx.array([200.0, -200.0])
+        y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
+        expected_clipped = mx.array([50.0, -50.0])
+        self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))
+
+        y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
+        expected_custom = mx.array([10.0, -10.0])
+        self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))
+
     def test_log_softmax(self):
         x = mx.array([1.0, 2.0, 3.0])
         y = nn.log_softmax(x)
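One caveat about the extreme-input checks above: with the implementation in this commit, `sigmoid(-200.0)` underflows to `0.0` in float32, so `clipped_silu(-200.0, ...)` evaluates to roughly `0.0` rather than the lower clip bound; the `[50.0, -50.0]` and `[10.0, -10.0]` expectations would only hold if the input itself were clipped before the sigmoid product. A small sketch to verify, using only `mlx.core` ops so it runs against released mlx as well:

import mlx.core as mx

x_extreme = mx.array([200.0, -200.0])
silu_out = x_extreme * mx.sigmoid(x_extreme)
print(silu_out)                                 # ~[200.0, -0.0]
print(mx.clip(silu_out, a_min=-50, a_max=50))   # ~[50.0, -0.0], not [50.0, -50.0]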