mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Add Step, ELU, SELU, Swish activation functions (#117)
* Add Step, ELU, SELU, Swish activation functions This commit adds the Step, ELU, SELU and Swish activations functions * add to the docs * review
This commit is contained in:

committed by
GitHub

parent
b9226c367c
commit
f5df47ec6e
@@ -449,6 +449,32 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_step_activation(self):
|
||||
x = mx.arange(-3, 4)
|
||||
expected = mx.array([0, 0, 0, 0, 0, 1, 1])
|
||||
y = nn.Step()(x)
|
||||
self.assertTrue(mx.array_equal(y, expected))
|
||||
|
||||
y = nn.Step(2)(x)
|
||||
expected = mx.array([0, 0, 0, 0, 0, 0, 1])
|
||||
self.assertTrue(mx.array_equal(y, expected))
|
||||
|
||||
def test_selu(self):
|
||||
x = mx.arange(-3, 4)
|
||||
expected = mx.array(
|
||||
[
|
||||
-1.670563817024231,
|
||||
-1.5201621055603027,
|
||||
-1.1113275289535522,
|
||||
0.0,
|
||||
1.0506999492645264,
|
||||
2.1013998985290527,
|
||||
3.152099847793579,
|
||||
]
|
||||
)
|
||||
y = nn.SELU()(x)
|
||||
self.assertTrue(mx.allclose(y, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user