This commit is contained in:
Goekdeniz-Guelmez 2025-06-16 22:45:09 +02:00
parent 3713832e5e
commit a315af8981
2 changed files with 5 additions and 5 deletions

View File

@ -6,6 +6,7 @@ from mlx.nn.layers.activations import (
GELU,
GLU,
SELU,
ClippedSiLU,
HardShrink,
Hardswish,
HardTanh,
@ -18,7 +19,6 @@ from mlx.nn.layers.activations import (
ReLU6,
Sigmoid,
SiLU,
ClippedSiLU,
Softmax,
Softmin,
Softplus,
@ -27,6 +27,7 @@ from mlx.nn.layers.activations import (
Step,
Tanh,
celu,
clipped_silu,
elu,
gelu,
gelu_approx,
@ -45,7 +46,6 @@ from mlx.nn.layers.activations import (
selu,
sigmoid,
silu,
clipped_silu,
softmax,
softmin,
softplus,

View File

@ -955,7 +955,7 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_silu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.silu(x)
@ -983,12 +983,12 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
x_extreme = mx.array([200.0, -200.0])
y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
expected_clipped = mx.array([50.0, -50.0])
self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))
y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
expected_custom = mx.array([10.0, -10.0])
self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))