Added activation functions: leaky_relu, relu6, softplus, elu, celu, log_sigmoid (#108)

* added leaky_relu, relu6, softplus, elu, celu, and log_sigmoid
* minor fixes for docstrings and benchmark imports
* fixed the elu implementation and added tests
* added tests for the optional params; changed the leaky_relu param name to match the PyTorch documentation
Jason authored 2023-12-10 19:31:38 -05:00, committed by GitHub
parent 71d1fff90a
commit b0cd092b7f
6 changed files with 344 additions and 0 deletions


@@ -1,14 +1,26 @@
# Copyright © 2023 Apple Inc.
from mlx.nn.layers.activations import (
    CELU,
    ELU,
    GELU,
    LeakyReLU,
    LogSigmoid,
    ReLU,
    ReLU6,
    SiLU,
    Softplus,
    celu,
    elu,
    gelu,
    gelu_approx,
    gelu_fast_approx,
    leaky_relu,
    log_sigmoid,
    relu,
    relu6,
    silu,
    softplus,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
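
With these re-exports, the new activations are reachable from mlx.nn both as plain functions and as Module layers. A minimal usage sketch, not part of the commit, assuming an environment with the mlx package installed (the sample values are arbitrary):

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# Functional form of the newly exported activations
print(nn.leaky_relu(x, negative_slope=0.1))
print(nn.softplus(x))

# Module form, usable inside a Sequential like any other layer
model = nn.Sequential(nn.ReLU6(), nn.LogSigmoid())
print(model(x))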


@@ -32,6 +32,47 @@ def relu(x):
    return mx.maximum(x, 0)


def leaky_relu(x, negative_slope=0.01):
    """Applies the Leaky Rectified Linear Unit.

    Simply ``mx.maximum(negative_slope * x, x)``.
    """
    return mx.maximum(negative_slope * x, x)


def elu(x, alpha=1.0):
    """Applies the Exponential Linear Unit.

    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
    """
    return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))


def relu6(x):
    r"""Applies the Rectified Linear Unit 6.

    Applies :math:`\min(\max(x, 0), 6)` element wise.
    """
    return mx.minimum(mx.maximum(x, 0), 6.0)


def softplus(x):
    r"""Applies the Softplus function.

    Applies :math:`\log(1 + \exp(x))` element wise.
    """
    return mx.logaddexp(x, 0)


def celu(x, alpha=1.0):
    r"""Applies the Continuously Differentiable Exponential Linear Unit.

    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
    element wise.
    """
    return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)


def silu(x):
    r"""Applies the Sigmoid Linear Unit.
@@ -41,6 +82,14 @@ def silu(x):
    return x * mx.sigmoid(x)


def log_sigmoid(x):
    r"""Applies the Log Sigmoid function.

    Applies :math:`\log(\sigma(x)) = -\log(1 + e^{-x})` element wise.
    """
    return -softplus(-x)


def gelu(x):
    r"""Applies the Gaussian Error Linear Units function.
@@ -99,11 +148,80 @@ class ReLU(Module):
    pass


class LeakyReLU(Module):
    r"""Applies the Leaky Rectified Linear Unit.

    Simply ``mx.maximum(negative_slope * x, x)``.

    Args:
        negative_slope: Controls the angle of the negative slope. Default: 1e-2.
    """

    def __init__(self, negative_slope=1e-2):
        super().__init__()
        self._negative_slope = negative_slope

    def __call__(self, x):
        return leaky_relu(x, self._negative_slope)


class ELU(Module):
    r"""Applies the Exponential Linear Unit.

    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.

    See :func:`elu`, for the functional equivalent.

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self._alpha = alpha

    def __call__(self, x):
        return elu(x, self._alpha)


@_make_activation_module(relu6)
class ReLU6(Module):
    pass


@_make_activation_module(softplus)
class Softplus(Module):
    pass


class CELU(Module):
    r"""Applies the Continuously Differentiable Exponential Linear Unit.

    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
    element wise.

    See :func:`celu`, for the functional equivalent.

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self._alpha = alpha

    def __call__(self, x):
        return celu(x, self._alpha)


@_make_activation_module(silu)
class SiLU(Module):
    pass


@_make_activation_module(log_sigmoid)
class LogSigmoid(Module):
    pass


class GELU(Module):
    r"""Applies the Gaussian Error Linear Units.


@@ -303,6 +303,81 @@ class TestNN(mlx_tests.MLXTestCase):
        eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
        self.assertTrue(all(tree_flatten(eq_tree)))

    def test_relu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.relu(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_leaky_relu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.leaky_relu(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

        y = nn.LeakyReLU(negative_slope=0.1)(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_elu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.elu(x)
        epsilon = 1e-4
        expected_y = mx.array([1.0, -0.6321, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

        y = nn.ELU(alpha=1.1)(x)
        epsilon = 1e-4
        expected_y = mx.array([1.0, -0.6953, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_relu6(self):
        x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
        y = nn.relu6(x)
        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
        self.assertEqual(y.shape, [5])
        self.assertEqual(y.dtype, mx.float32)

    def test_softplus(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.softplus(x)
        epsilon = 1e-4
        expected_y = mx.array([1.3133, 0.3133, 0.6931])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_celu(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.celu(x)
        epsilon = 1e-4
        expected_y = mx.array([1.0, -0.6321, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

        y = nn.CELU(alpha=1.1)(x)
        expected_y = mx.array([1.0, -0.6568, 0.0])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)

    def test_log_sigmoid(self):
        x = mx.array([1.0, -1.0, 0.0])
        y = nn.log_sigmoid(x)
        epsilon = 1e-4
        expected_y = mx.array([-0.3133, -1.3133, -0.6931])
        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
        self.assertEqual(y.shape, [3])
        self.assertEqual(y.dtype, mx.float32)


if __name__ == "__main__":
    unittest.main()
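
Since the commit message mentions aligning the leaky_relu parameter name with the PyTorch documentation, a further cross-check against torch.nn.functional is straightforward. This is a hypothetical sketch, not part of the test suite above; it assumes torch and numpy are installed alongside mlx:

import numpy as np
import mlx.core as mx
import mlx.nn as nn
import torch
import torch.nn.functional as F

x_np = np.linspace(-5, 5, 11, dtype=np.float32)
x_mx, x_pt = mx.array(x_np), torch.from_numpy(x_np)

# (MLX output, PyTorch reference) for each newly added activation
pairs = [
    (nn.leaky_relu(x_mx, negative_slope=0.1), F.leaky_relu(x_pt, negative_slope=0.1)),
    (nn.elu(x_mx, alpha=1.5), F.elu(x_pt, alpha=1.5)),
    (nn.relu6(x_mx), F.relu6(x_pt)),
    (nn.softplus(x_mx), F.softplus(x_pt)),
    (nn.celu(x_mx, alpha=0.5), F.celu(x_pt, alpha=0.5)),
    (nn.log_sigmoid(x_mx), F.logsigmoid(x_pt)),
]

for got, ref in pairs:
    np.testing.assert_allclose(np.array(got), ref.numpy(), atol=1e-5)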