diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 10bbe821e..2934dbd9d 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -955,6 +955,43 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertEqual(y.shape, (3,)) self.assertEqual(y.dtype, mx.float32) + + def test_silu(self): + x = mx.array([1.0, -1.0, 0.0]) + y = nn.silu(x) + epsilon = 1e-4 + expected_y = mx.array([0.7311, -0.2689, 0.0]) + self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) + self.assertEqual(y.shape, (3,)) + self.assertEqual(y.dtype, mx.float32) + + y = nn.SiLU()(x) + self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) + self.assertEqual(y.shape, (3,)) + self.assertEqual(y.dtype, mx.float32) + + def test_clipped_silu(self): + x = mx.array([1.0, -1.0, 0.0]) + y = nn.clipped_silu(x, a_min=-100, a_max=100) + epsilon = 1e-4 + expected_y = mx.array([0.7311, -0.2689, 0.0]) + self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) + self.assertEqual(y.shape, (3,)) + self.assertEqual(y.dtype, mx.float32) + + y = nn.ClippedSiLU(a_min=-100, a_max=100)(x) + self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) + self.assertEqual(y.shape, (3,)) + self.assertEqual(y.dtype, mx.float32) + + x_extreme = mx.array([200.0, -200.0]) + y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50) + expected_clipped = mx.array([50.0, -50.0]) + self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon)) + + y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme) + expected_custom = mx.array([10.0, -10.0]) + self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon)) def test_log_softmax(self): x = mx.array([1.0, 2.0, 3.0])