mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid * minor fixes for docstring and benchmark imports * fixed elu implementation and added tests * added tests for optional param, changed leaky_relu param to fit pytorch documentation
This commit is contained in:
parent
71d1fff90a
commit
b0cd092b7f
@ -6,6 +6,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def int_or_list(x):
|
def int_or_list(x):
|
||||||
@ -99,6 +100,48 @@ def relu(x):
|
|||||||
mx.eval(y)
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def leaky_relu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.leaky_relu(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def elu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.elu(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def relu6(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.relu6(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def softplus(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.softplus(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def celu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.celu(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
|
def log_sigmoid(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = nn.log_sigmoid(y)
|
||||||
|
mx.eval(y)
|
||||||
|
|
||||||
|
|
||||||
def scalar_mult(x):
|
def scalar_mult(x):
|
||||||
y = x
|
y = x
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
@ -277,6 +320,24 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "relu":
|
elif args.benchmark == "relu":
|
||||||
print(bench(relu, x))
|
print(bench(relu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "leaky_relu":
|
||||||
|
print(bench(leaky_relu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "elu":
|
||||||
|
print(bench(elu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "relu6":
|
||||||
|
print(bench(relu6, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "softplus":
|
||||||
|
print(bench(softplus, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "celu":
|
||||||
|
print(bench(celu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "log_sigmoid":
|
||||||
|
print(bench(log_sigmoid, x))
|
||||||
|
|
||||||
elif args.benchmark == "scalar_mul":
|
elif args.benchmark == "scalar_mul":
|
||||||
print(bench(scalar_mult, x))
|
print(bench(scalar_mult, x))
|
||||||
|
|
||||||
|
@ -115,6 +115,54 @@ def relu(x):
|
|||||||
sync_if_needed(x)
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def leaky_relu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.nn.functional.leaky_relu(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def elu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.nn.functional.elu(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def celu(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.nn.functional.celu(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def relu6(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.nn.functional.relu6(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def softplus(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.nn.functional.softplus(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def log_sigmoid(x):
|
||||||
|
y = x
|
||||||
|
for i in range(100):
|
||||||
|
y = torch.nn.functional.logsigmoid(y)
|
||||||
|
sync_if_needed(x)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def scalar_mult(x):
|
def scalar_mult(x):
|
||||||
y = x
|
y = x
|
||||||
@ -302,6 +350,24 @@ if __name__ == "__main__":
|
|||||||
elif args.benchmark == "relu":
|
elif args.benchmark == "relu":
|
||||||
print(bench(relu, x))
|
print(bench(relu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "leaky_relu":
|
||||||
|
print(bench(leaky_relu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "elu":
|
||||||
|
print(bench(elu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "relu6":
|
||||||
|
print(bench(relu6, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "softplus":
|
||||||
|
print(bench(softplus, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "celu":
|
||||||
|
print(bench(celu, x))
|
||||||
|
|
||||||
|
elif args.benchmark == "log_sigmoid":
|
||||||
|
print(bench(log_sigmoid, x))
|
||||||
|
|
||||||
elif args.benchmark == "scalar_mul":
|
elif args.benchmark == "scalar_mul":
|
||||||
print(bench(scalar_mult, x))
|
print(bench(scalar_mult, x))
|
||||||
|
|
||||||
|
@ -193,6 +193,18 @@ if __name__ == "__main__":
|
|||||||
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
|
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
|
||||||
compare_filtered("relu --size 32x16x1024")
|
compare_filtered("relu --size 32x16x1024")
|
||||||
compare_filtered("relu --size 32x16x1024 --cpu")
|
compare_filtered("relu --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("leaky_relu --size 32x16x1024")
|
||||||
|
compare_filtered("leaky_relu --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("elu --size 32x16x1024")
|
||||||
|
compare_filtered("elu --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("relu6 --size 32x16x1024")
|
||||||
|
compare_filtered("relu6 --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("softplus --size 32x16x1024")
|
||||||
|
compare_filtered("softplus --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("celu --size 32x16x1024")
|
||||||
|
compare_filtered("celu --size 32x16x1024 --cpu")
|
||||||
|
compare_filtered("log_sigmoid --size 32x16x1024")
|
||||||
|
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
|
||||||
compare_filtered("scalar_mul --size 32x16x1024")
|
compare_filtered("scalar_mul --size 32x16x1024")
|
||||||
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
||||||
compare_filtered("cross_entropy --size 256x1024")
|
compare_filtered("cross_entropy --size 256x1024")
|
||||||
|
@ -1,14 +1,26 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
from mlx.nn.layers.activations import (
|
from mlx.nn.layers.activations import (
|
||||||
|
CELU,
|
||||||
|
ELU,
|
||||||
GELU,
|
GELU,
|
||||||
|
LeakyReLU,
|
||||||
|
LogSigmoid,
|
||||||
ReLU,
|
ReLU,
|
||||||
|
ReLU6,
|
||||||
SiLU,
|
SiLU,
|
||||||
|
Softplus,
|
||||||
|
celu,
|
||||||
|
elu,
|
||||||
gelu,
|
gelu,
|
||||||
gelu_approx,
|
gelu_approx,
|
||||||
gelu_fast_approx,
|
gelu_fast_approx,
|
||||||
|
leaky_relu,
|
||||||
|
log_sigmoid,
|
||||||
relu,
|
relu,
|
||||||
|
relu6,
|
||||||
silu,
|
silu,
|
||||||
|
softplus,
|
||||||
)
|
)
|
||||||
from mlx.nn.layers.base import Module
|
from mlx.nn.layers.base import Module
|
||||||
from mlx.nn.layers.containers import Sequential
|
from mlx.nn.layers.containers import Sequential
|
||||||
|
@ -32,6 +32,47 @@ def relu(x):
|
|||||||
return mx.maximum(x, 0)
|
return mx.maximum(x, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def leaky_relu(x, negative_slope=0.01):
|
||||||
|
"""Applies the Leaky Rectified Linear Unit.
|
||||||
|
|
||||||
|
Simply ``mx.maximum(negative_slope * x, x)``.
|
||||||
|
"""
|
||||||
|
return mx.maximum(negative_slope * x, x)
|
||||||
|
|
||||||
|
|
||||||
|
def elu(x, alpha=1.0):
|
||||||
|
"""Applies the Exponential Linear Unit.
|
||||||
|
|
||||||
|
Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
|
||||||
|
"""
|
||||||
|
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
|
||||||
|
|
||||||
|
|
||||||
|
def relu6(x):
|
||||||
|
r"""Applies the Rectified Linear Unit 6.
|
||||||
|
|
||||||
|
Applies :math:`\min(\max(x, 0), 6)` element wise.
|
||||||
|
"""
|
||||||
|
return mx.minimum(mx.maximum(x, 0), 6.0)
|
||||||
|
|
||||||
|
|
||||||
|
def softplus(x):
|
||||||
|
r"""Applies the Softplus function.
|
||||||
|
|
||||||
|
Applies :math:`\log(1 + \exp(x))` element wise.
|
||||||
|
"""
|
||||||
|
return mx.logaddexp(x, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def celu(x, alpha=1.0):
|
||||||
|
r"""Applies the Continuously Differentiable Exponential Linear Unit.
|
||||||
|
|
||||||
|
Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
|
||||||
|
element wise.
|
||||||
|
"""
|
||||||
|
return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)
|
||||||
|
|
||||||
|
|
||||||
def silu(x):
|
def silu(x):
|
||||||
r"""Applies the Sigmoid Linear Unit.
|
r"""Applies the Sigmoid Linear Unit.
|
||||||
|
|
||||||
@ -41,6 +82,14 @@ def silu(x):
|
|||||||
return x * mx.sigmoid(x)
|
return x * mx.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
def log_sigmoid(x):
|
||||||
|
r"""Applies the Log Sigmoid function.
|
||||||
|
|
||||||
|
Applies :math:`\log(\sigma(x)) = -\log(1 + e^{-x})` element wise.
|
||||||
|
"""
|
||||||
|
return -softplus(-x)
|
||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
r"""Applies the Gaussian Error Linear Units function.
|
r"""Applies the Gaussian Error Linear Units function.
|
||||||
|
|
||||||
@ -99,11 +148,80 @@ class ReLU(Module):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LeakyReLU(Module):
|
||||||
|
r"""Applies the Leaky Rectified Linear Unit.
|
||||||
|
|
||||||
|
Simply ``mx.maximum(negative_slope * x, x)``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
negative_slope: Controls the angle of the negative slope. Default: 1e-2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, negative_slope=1e-2):
|
||||||
|
super().__init__()
|
||||||
|
self._negative_slope = negative_slope
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return leaky_relu(x, self._negative_slope)
|
||||||
|
|
||||||
|
|
||||||
|
class ELU(Module):
|
||||||
|
r"""Applies the Exponential Linear Unit.
|
||||||
|
Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
|
||||||
|
|
||||||
|
See :func:`elu`, for the functional equivalent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, alpha=1.0):
|
||||||
|
super().__init__()
|
||||||
|
self._alpha = alpha
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return elu(x, self._alpha)
|
||||||
|
|
||||||
|
|
||||||
|
@_make_activation_module(relu6)
|
||||||
|
class ReLU6(Module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@_make_activation_module(softplus)
|
||||||
|
class Softplus(Module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CELU(Module):
|
||||||
|
r"""Applies the Continuously Differentiable Exponential Linear Unit.
|
||||||
|
Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
|
||||||
|
element wise.
|
||||||
|
|
||||||
|
See :func:`celu`, for the functional equivalent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, alpha=1.0):
|
||||||
|
super().__init__()
|
||||||
|
self._alpha = alpha
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
return celu(x, self._alpha)
|
||||||
|
|
||||||
|
|
||||||
@_make_activation_module(silu)
|
@_make_activation_module(silu)
|
||||||
class SiLU(Module):
|
class SiLU(Module):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@_make_activation_module(log_sigmoid)
|
||||||
|
class LogSigmoid(Module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class GELU(Module):
|
class GELU(Module):
|
||||||
r"""Applies the Gaussian Error Linear Units.
|
r"""Applies the Gaussian Error Linear Units.
|
||||||
|
|
||||||
|
@ -303,6 +303,81 @@ class TestNN(mlx_tests.MLXTestCase):
|
|||||||
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
|
||||||
self.assertTrue(all(tree_flatten(eq_tree)))
|
self.assertTrue(all(tree_flatten(eq_tree)))
|
||||||
|
|
||||||
|
def test_relu(self):
|
||||||
|
x = mx.array([1.0, -1.0, 0.0])
|
||||||
|
y = nn.relu(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
|
||||||
|
self.assertEqual(y.shape, [3])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
def test_leaky_relu(self):
|
||||||
|
x = mx.array([1.0, -1.0, 0.0])
|
||||||
|
y = nn.leaky_relu(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
|
||||||
|
self.assertEqual(y.shape, [3])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
y = nn.LeakyReLU(negative_slope=0.1)(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
|
||||||
|
self.assertEqual(y.shape, [3])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
def test_elu(self):
|
||||||
|
x = mx.array([1.0, -1.0, 0.0])
|
||||||
|
y = nn.elu(x)
|
||||||
|
epsilon = 1e-4
|
||||||
|
expected_y = mx.array([1.0, -0.6321, 0.0])
|
||||||
|
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||||
|
self.assertEqual(y.shape, [3])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
y = nn.ELU(alpha=1.1)(x)
|
||||||
|
epsilon = 1e-4
|
||||||
|
expected_y = mx.array([1.0, -0.6953, 0.0])
|
||||||
|
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||||
|
self.assertEqual(y.shape, [3])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
def test_relu6(self):
|
||||||
|
x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
|
||||||
|
y = nn.relu6(x)
|
||||||
|
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
|
||||||
|
self.assertEqual(y.shape, [5])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
def test_softplus(self):
|
||||||
|
x = mx.array([1.0, -1.0, 0.0])
|
||||||
|
y = nn.softplus(x)
|
||||||
|
epsilon = 1e-4
|
||||||
|
expected_y = mx.array([1.3133, 0.3133, 0.6931])
|
||||||
|
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||||
|
self.assertEqual(y.shape, [3])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
def test_celu(self):
|
||||||
|
x = mx.array([1.0, -1.0, 0.0])
|
||||||
|
y = nn.celu(x)
|
||||||
|
epsilon = 1e-4
|
||||||
|
expected_y = mx.array([1.0, -0.6321, 0.0])
|
||||||
|
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||||
|
self.assertEqual(y.shape, [3])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
y = nn.CELU(alpha=1.1)(x)
|
||||||
|
expected_y = mx.array([1.0, -0.6568, 0.0])
|
||||||
|
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||||
|
self.assertEqual(y.shape, [3])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
def test_log_sigmoid(self):
|
||||||
|
x = mx.array([1.0, -1.0, 0.0])
|
||||||
|
y = nn.log_sigmoid(x)
|
||||||
|
epsilon = 1e-4
|
||||||
|
expected_y = mx.array([-0.3133, -1.3133, -0.6931])
|
||||||
|
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||||
|
self.assertEqual(y.shape, [3])
|
||||||
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user