test: Add unit test for ReLUSquared activation function

This commit is contained in:
John Mai 2025-06-15 17:07:33 +08:00
parent 940f64fe6a
commit cbd353bf73

View File

@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_relu_squared(self):
x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
y = nn.relu_squared(x)
self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
self.assertEqual(y.shape, (5,))
self.assertEqual(y.dtype, mx.float32)
def test_leaky_relu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.leaky_relu(x)