From cbd353bf73794c27cbcde672a7a32abbeaf80238 Mon Sep 17 00:00:00 2001 From: John Mai Date: Sun, 15 Jun 2025 17:07:33 +0800 Subject: [PATCH] test: Add unit test for ReLUSquared activation function --- python/tests/test_nn.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 13e31ad96..ca17b20be 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertEqual(y.shape, (3,)) self.assertEqual(y.dtype, mx.float32) + def test_relu_squared(self): + x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0]) + y = nn.relu_squared(x) + self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0]))) + self.assertEqual(y.shape, (5,)) + self.assertEqual(y.dtype, mx.float32) + def test_leaky_relu(self): x = mx.array([1.0, -1.0, 0.0]) y = nn.leaky_relu(x)