diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 205a755fa..929d3b032 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -6,6 +6,7 @@ from mlx.nn.layers.activations import (
     GELU,
     GLU,
     SELU,
+    ClippedSiLU,
     HardShrink,
     Hardswish,
     HardTanh,
@@ -18,7 +19,6 @@ from mlx.nn.layers.activations import (
     ReLU6,
     Sigmoid,
     SiLU,
-    ClippedSiLU,
     Softmax,
     Softmin,
     Softplus,
@@ -27,6 +27,7 @@ from mlx.nn.layers.activations import (
     Step,
     Tanh,
     celu,
+    clipped_silu,
     elu,
     gelu,
     gelu_approx,
@@ -45,7 +46,6 @@ from mlx.nn.layers.activations import (
     selu,
     sigmoid,
     silu,
-    clipped_silu,
     softmax,
     softmin,
     softplus,
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 2934dbd9d..4bec32d2c 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -955,7 +955,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
-
+
     def test_silu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.silu(x)
@@ -983,12 +983,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
-
+
         x_extreme = mx.array([200.0, -200.0])
         y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
         expected_clipped = mx.array([50.0, -50.0])
         self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))
-
+
         y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
         expected_custom = mx.array([10.0, -10.0])
         self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))
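
For reference, a minimal usage sketch of the two re-exported names touched by this diff, nn.clipped_silu and nn.ClippedSiLU. The a_min/a_max values mirror those used in the test above; the import layout assumes the standard mlx package, and the default clipping bounds are not shown in this diff.

import mlx.core as mx
import mlx.nn as nn

x = mx.array([200.0, -200.0, 1.0])

# Functional form: SiLU output clipped to [a_min, a_max].
y_fn = nn.clipped_silu(x, a_min=-50, a_max=50)  # extreme inputs saturate at +/-50

# Module form: clipping bounds are stored on the layer instance.
layer = nn.ClippedSiLU(a_min=-10, a_max=10)
y_mod = layer(x)                                # saturates at +/-10 instead

print(y_fn, y_mod)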