diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 13e31ad96..ca17b20be 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertEqual(y.shape, (3,)) self.assertEqual(y.dtype, mx.float32) + def test_relu_squared(self): + x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0]) + y = nn.relu_squared(x) + self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0]))) + self.assertEqual(y.shape, (5,)) + self.assertEqual(y.dtype, mx.float32) + def test_leaky_relu(self): x = mx.array([1.0, -1.0, 0.0]) y = nn.leaky_relu(x)