Added activation functions: leaky_relu, relu6, softplus, elu, celu, logsigmoid (#108)

* Added leaky_relu, relu6, softplus, elu, celu, and logsigmoid (see the usage sketch below)
* Minor fixes to docstrings and benchmark imports
* Fixed the elu implementation and added tests
* Added tests for the optional parameter; changed the leaky_relu parameter name to match the PyTorch documentation
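
A minimal usage sketch of the new functions, mirroring the calls exercised in the tests in this commit. The default parameter values noted in the comments are inferred from the expected test outputs, not from separate documentation:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([1.0, -1.0, 0.0])
    nn.relu(x)                           # [1.0, 0.0, 0.0]
    nn.leaky_relu(x)                     # [1.0, -0.01, 0.0]; default slope inferred as 0.01
    nn.LeakyReLU(negative_slope=0.1)(x)  # module form with an explicit slope
    nn.elu(x)                            # x if x > 0 else alpha * (exp(x) - 1); alpha inferred as 1.0
    nn.ELU(alpha=1.1)(x)                 # module form with an explicit alpha
    nn.relu6(mx.array([7.0]))            # min(max(x, 0), 6) -> [6.0]
    nn.softplus(x)                       # log(1 + exp(x))
    nn.celu(x)                           # max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    nn.log_sigmoid(x)                    # -softplus(-x)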
Authored by Jason on 2023-12-10 19:31:38 -05:00; committed by GitHub
parent 71d1fff90a
commit b0cd092b7f
6 changed files with 344 additions and 0 deletions

@@ -303,6 +303,81 @@ class TestNN(mlx_tests.MLXTestCase):
        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
        self.assertTrue(all(tree_flatten(eq_tree)))

    def test_relu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.relu(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)
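
    # LeakyReLU(x) = x if x >= 0 else negative_slope * x; the default slope of
    # 0.01 is inferred from the expected output below.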
    def test_leaky_relu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.leaky_relu(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

        y = nn.LeakyReLU(negative_slope=0.1)(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)
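
    # ELU(x) = x if x > 0 else alpha * (exp(x) - 1); exp(-1) - 1 ≈ -0.6321 and
    # 1.1 * (exp(-1) - 1) ≈ -0.6953, matching the expected values below.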
    def test_elu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.elu(x)
        epsilon = 1e-4
        expected_y = mx.array([1.0, -0.6321, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

        y = nn.ELU(alpha=1.1)(x)
        expected_y = mx.array([1.0, -0.6953, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)
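
    # ReLU6(x) = min(max(x, 0), 6).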
    def test_relu6(self):
        x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
        y = nn.relu6(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
        self.assertEqual(y.shape, [5])
        self.assertEqual(y.dtype, mx.float32)
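
    # Softplus(x) = log(1 + exp(x)); log(2) ≈ 0.6931 at x = 0.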
    def test_softplus(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.softplus(x)
        epsilon = 1e-4
        expected_y = mx.array([1.3133, 0.3133, 0.6931])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)
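
    # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)); it reduces to
    # ELU when alpha = 1, and 1.1 * (exp(-1 / 1.1) - 1) ≈ -0.6568.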
    def test_celu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.celu(x)
        epsilon = 1e-4
        expected_y = mx.array([1.0, -0.6321, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

        y = nn.CELU(alpha=1.1)(x)
        expected_y = mx.array([1.0, -0.6568, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)
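
    # log_sigmoid(x) = log(1 / (1 + exp(-x))) = -softplus(-x).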
    def test_log_sigmoid(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.log_sigmoid(x)
        epsilon = 1e-4
        expected_y = mx.array([-0.3133, -1.3133, -0.6931])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

if __name__ == "__main__":
    unittest.main()