mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 05:41:14 +08:00
adding test for silu and clipped silu
This commit is contained in:
parent
9cb6df5960
commit
3713832e5e
@ -955,6 +955,43 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_silu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.silu(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.7311, -0.2689, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.SiLU()(x)
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_clipped_silu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.clipped_silu(x, a_min=-100, a_max=100)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.7311, -0.2689, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.ClippedSiLU(a_min=-100, a_max=100)(x)
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
x_extreme = mx.array([200.0, -200.0])
|
||||
y_clipped = nn.clipped_silu(x_extreme, a_min=-50, a_max=50)
|
||||
expected_clipped = mx.array([50.0, -50.0])
|
||||
self.assertTrue(mx.all(mx.abs(y_clipped - expected_clipped) < epsilon))
|
||||
|
||||
y_custom = nn.ClippedSiLU(a_min=-10, a_max=10)(x_extreme)
|
||||
expected_custom = mx.array([10.0, -10.0])
|
||||
self.assertTrue(mx.all(mx.abs(y_custom - expected_custom) < epsilon))
|
||||
|
||||
def test_log_softmax(self):
|
||||
x = mx.array([1.0, 2.0, 3.0])
|
||||
|
Loading…
Reference in New Issue
Block a user