Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-01 12:49:44 +08:00)
feat: add softsign, softmax, hardswish, logsoftmax activation function (#309)
* feat: add softsign activation function
* run pre-commit
* Add Softsign activation function
* Add Softsign activation function
* Add documentation for ReLU6, Softplus, and Softsign activations
* Update activation functions in neural network layers
* Add LogSoftmax and Hardswish activations
* run pre-commit
* Update activations.py
* Added acknowledgements
* Fix activation function comments
* Fix activation functions in neural network layers
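For context, the tests added in the diff below call the new activations through the functional mlx.nn API. A minimal usage sketch based on those calls (illustrative only, assuming a build that includes this commit; the class-based modules such as Softsign mentioned in the commit message are not shown in this diff):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, -1.0, 0.0])
    print(nn.softsign(x))      # ~[0.5, -0.5, 0.0]
    print(nn.softmax(x))       # ~[0.6652, 0.0900, 0.2447]
    print(nn.log_softmax(mx.array([1.0, 2.0, 3.0])))            # ~[-2.4076, -1.4076, -0.4076]
    print(nn.hardswish(mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])))  # ~[0.0, -0.375, 0.0, 1.125, 3.0]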
@@ -667,6 +667,15 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [5])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softmax(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softmax(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.6652, 0.0900, 0.2447])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_softplus(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.softplus(x)
@@ -676,6 +685,15 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_softsign(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softsign(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.5, -0.5, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_celu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.celu(x)
@@ -691,6 +709,15 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, [3])
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_log_softmax(self):
+        x = mx.array([1.0, 2.0, 3.0])
+        y = nn.log_softmax(x)
+        epsilon = 1e-4
+        expected_y = mx.array([-2.4076, -1.4076, -0.4076])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_log_sigmoid(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.log_sigmoid(x)
@@ -712,6 +739,15 @@ class TestNN(mlx_tests.MLXTestCase):
             mx.array([0.8651, -0.3034, 0.0000, 0.3752]),
         )
 
+    def test_hardswish(self):
+        x = mx.array([-3.0, -1.5, 0.0, 1.5, 3.0])
+        y = nn.hardswish(x)
+        epsilon = 1e-4
+        expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [5])
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_rope(self):
         for kwargs in [{}, {"traditional": False}, {"base": 10000}]:
             rope = nn.RoPE(4, **kwargs)