mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid * minor fixes for docstring and benchmark imports * fixed elu implementation and added tests * added tests for optional param, changed leaky_relu param to fit pytorch documentation
This commit is contained in:
@@ -303,6 +303,81 @@ class TestNN(mlx_tests.MLXTestCase):
|
||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||
|
||||
def test_relu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.relu(x)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_leaky_relu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.leaky_relu(x)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.LeakyReLU(negative_slope=0.1)(x)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_elu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.elu(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([1.0, -0.6321, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.ELU(alpha=1.1)(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([1.0, -0.6953, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_relu6(self):
|
||||
x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
|
||||
y = nn.relu6(x)
|
||||
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
|
||||
self.assertEqual(y.shape, [5])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softplus(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.softplus(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([1.3133, 0.3133, 0.6931])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_celu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.celu(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([1.0, -0.6321, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.CELU(alpha=1.1)(x)
|
||||
expected_y = mx.array([1.0, -0.6568, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_log_sigmoid(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.log_sigmoid(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([-0.3133, -1.3133, -0.6931])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user