Add Step, ELU, SELU, Swish activation functions (#117)

* Add Step, ELU, SELU, Swish activation functions

This commit adds the Step, ELU, SELU and Swish activations functions

* add to the docs

* review
This commit is contained in:
Nicholas Santavas
2023-12-12 02:04:07 +01:00
committed by GitHub
parent b9226c367c
commit f5df47ec6e
7 changed files with 132 additions and 2 deletions

View File

@@ -449,6 +449,32 @@ class TestNN(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_step_activation(self):
x = mx.arange(-3, 4)
expected = mx.array([0, 0, 0, 0, 0, 1, 1])
y = nn.Step()(x)
self.assertTrue(mx.array_equal(y, expected))
y = nn.Step(2)(x)
expected = mx.array([0, 0, 0, 0, 0, 0, 1])
self.assertTrue(mx.array_equal(y, expected))
def test_selu(self):
x = mx.arange(-3, 4)
expected = mx.array(
[
-1.670563817024231,
-1.5201621055603027,
-1.1113275289535522,
0.0,
1.0506999492645264,
2.1013998985290527,
3.152099847793579,
]
)
y = nn.SELU()(x)
self.assertTrue(mx.allclose(y, expected))
if __name__ == "__main__":
unittest.main()