This commit is contained in:
Goekdeniz-Guelmez 2025-06-16 22:45:09 +02:00
parent 3713832e5e
commit a315af8981
2 changed files with 5 additions and 5 deletions

View File

@ -6,6 +6,7 @@ from mlx.nn.layers.activations import (
GELU, GELU,
GLU, GLU,
SELU, SELU,
ClippedSiLU,
HardShrink, HardShrink,
Hardswish, Hardswish,
HardTanh, HardTanh,
@ -18,7 +19,6 @@ from mlx.nn.layers.activations import (
ReLU6, ReLU6,
Sigmoid, Sigmoid,
SiLU, SiLU,
ClippedSiLU,
Softmax, Softmax,
Softmin, Softmin,
Softplus, Softplus,
@ -27,6 +27,7 @@ from mlx.nn.layers.activations import (
Step, Step,
Tanh, Tanh,
celu, celu,
clipped_silu,
elu, elu,
gelu, gelu,
gelu_approx, gelu_approx,
@ -45,7 +46,6 @@ from mlx.nn.layers.activations import (
selu, selu,
sigmoid, sigmoid,
silu, silu,
clipped_silu,
softmax, softmax,
softmin, softmin,
softplus, softplus,

View File

@ -955,7 +955,7 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, (3,)) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_silu(self): def test_silu(self):
x = mx.array([1.0, -1.0, 0.0]) x = mx.array([1.0, -1.0, 0.0])
y = nn.silu(x) y = nn.silu(x)
@ -983,12 +983,12 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, (3,)) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
x_extreme = mx.array([200.0, -200.0]) x_extreme = mx.array([200.0, -200.0])
y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50) y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
expected_clipped = mx.array([50.0, -50.0]) expected_clipped = mx.array([50.0, -50.0])
self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon)) self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))
y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme) y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
expected_custom = mx.array([10.0, -10.0]) expected_custom = mx.array([10.0, -10.0])
self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon)) self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))