mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
test: Add unit test for ReLUSquared activation function
This commit is contained in:
parent
940f64fe6a
commit
cbd353bf73
@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_relu_squared(self):
|
||||
x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
|
||||
y = nn.relu_squared(x)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
|
||||
self.assertEqual(y.shape, (5,))
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_leaky_relu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.leaky_relu(x)
|
||||
|
Loading…
Reference in New Issue
Block a user