diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index 5f2c3619f..d73610b17 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -6,6 +6,7 @@ import os
 import time
 
 import mlx.core as mx
+import mlx.nn as nn
 
 
 def int_or_list(x):
@@ -99,6 +100,48 @@ def relu(x):
     mx.eval(y)
 
 
+def leaky_relu(x):
+    y = x
+    for i in range(100):
+        y = nn.leaky_relu(y)
+    mx.eval(y)
+
+
+def elu(x):
+    y = x
+    for i in range(100):
+        y = nn.elu(y)
+    mx.eval(y)
+
+
+def relu6(x):
+    y = x
+    for i in range(100):
+        y = nn.relu6(y)
+    mx.eval(y)
+
+
+def softplus(x):
+    y = x
+    for i in range(100):
+        y = nn.softplus(y)
+    mx.eval(y)
+
+
+def celu(x):
+    y = x
+    for i in range(100):
+        y = nn.celu(y)
+    mx.eval(y)
+
+
+def log_sigmoid(x):
+    y = x
+    for i in range(100):
+        y = nn.log_sigmoid(y)
+    mx.eval(y)
+
+
 def scalar_mult(x):
     y = x
     for i in range(100):
@@ -277,6 +320,24 @@ if __name__ == "__main__":
     elif args.benchmark == "relu":
         print(bench(relu, x))
 
+    elif args.benchmark == "leaky_relu":
+        print(bench(leaky_relu, x))
+
+    elif args.benchmark == "elu":
+        print(bench(elu, x))
+
+    elif args.benchmark == "relu6":
+        print(bench(relu6, x))
+
+    elif args.benchmark == "softplus":
+        print(bench(softplus, x))
+
+    elif args.benchmark == "celu":
+        print(bench(celu, x))
+
+    elif args.benchmark == "log_sigmoid":
+        print(bench(log_sigmoid, x))
+
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))
 
diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py
index f275b9da4..1c92bd7a1 100644
--- a/benchmarks/python/comparative/bench_torch.py
+++ b/benchmarks/python/comparative/bench_torch.py
@@ -115,6 +115,54 @@ def relu(x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def leaky_relu(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.leaky_relu(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def elu(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.elu(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def celu(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.celu(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def relu6(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.relu6(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def softplus(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.softplus(y)
+    sync_if_needed(x)
+
+
+@torch.no_grad()
+def log_sigmoid(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.logsigmoid(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def scalar_mult(x):
     y = x
@@ -302,6 +350,24 @@ if __name__ == "__main__":
     elif args.benchmark == "relu":
         print(bench(relu, x))
 
+    elif args.benchmark == "leaky_relu":
+        print(bench(leaky_relu, x))
+
+    elif args.benchmark == "elu":
+        print(bench(elu, x))
+
+    elif args.benchmark == "relu6":
+        print(bench(relu6, x))
+
+    elif args.benchmark == "softplus":
+        print(bench(softplus, x))
+
+    elif args.benchmark == "celu":
+        print(bench(celu, x))
+
+    elif args.benchmark == "log_sigmoid":
+        print(bench(log_sigmoid, x))
+
     elif args.benchmark == "scalar_mul":
         print(bench(scalar_mult, x))
 
diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py
index f357165fd..3e64625e5 100644
--- a/benchmarks/python/comparative/compare.py
+++ b/benchmarks/python/comparative/compare.py
@@ -193,6 +193,18 @@ if __name__ == "__main__":
     compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
     compare_filtered("relu --size 32x16x1024")
     compare_filtered("relu --size 32x16x1024 --cpu")
+    compare_filtered("leaky_relu --size 32x16x1024")
+    compare_filtered("leaky_relu --size 32x16x1024 --cpu")
+    compare_filtered("elu --size 32x16x1024")
+    compare_filtered("elu --size 32x16x1024 --cpu")
+    compare_filtered("relu6 --size 32x16x1024")
+    compare_filtered("relu6 --size 32x16x1024 --cpu")
+    compare_filtered("softplus --size 32x16x1024")
+    compare_filtered("softplus --size 32x16x1024 --cpu")
+    compare_filtered("celu --size 32x16x1024")
+    compare_filtered("celu --size 32x16x1024 --cpu")
+    compare_filtered("log_sigmoid --size 32x16x1024")
+    compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
     compare_filtered("scalar_mul --size 32x16x1024")
     compare_filtered("scalar_mul --size 32x16x1024 --cpu")
     compare_filtered("cross_entropy --size 256x1024")
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 133a04494..e78f161a8 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -1,14 +1,26 @@
 # Copyright © 2023 Apple Inc.
 
 from mlx.nn.layers.activations import (
+    CELU,
+    ELU,
     GELU,
+    LeakyReLU,
+    LogSigmoid,
     ReLU,
+    ReLU6,
     SiLU,
+    Softplus,
+    celu,
+    elu,
     gelu,
     gelu_approx,
     gelu_fast_approx,
+    leaky_relu,
+    log_sigmoid,
     relu,
+    relu6,
     silu,
+    softplus,
 )
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 106e11cbd..f924626a8 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -32,6 +32,47 @@ def relu(x):
     return mx.maximum(x, 0)
 
 
+def leaky_relu(x, negative_slope=0.01):
+    """Applies the Leaky Rectified Linear Unit.
+
+    Simply ``mx.maximum(negative_slope * x, x)``.
+    """
+    return mx.maximum(negative_slope * x, x)
+
+
+def elu(x, alpha=1.0):
+    """Applies the Exponential Linear Unit.
+
+    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
+    """
+    return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
+
+
+def relu6(x):
+    r"""Applies the Rectified Linear Unit 6.
+
+    Applies :math:`\min(\max(x, 0), 6)` element wise.
+    """
+    return mx.minimum(mx.maximum(x, 0), 6.0)
+
+
+def softplus(x):
+    r"""Applies the Softplus function.
+
+    Applies :math:`\log(1 + \exp(x))` element wise.
+    """
+    return mx.logaddexp(x, 0)
+
+
+def celu(x, alpha=1.0):
+    r"""Applies the Continuously Differentiable Exponential Linear Unit.
+
+    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
+    element wise.
+    """
+    return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)
+
+
 def silu(x):
     r"""Applies the Sigmoid Linear Unit.
 
@@ -41,6 +82,14 @@ def silu(x):
     return x * mx.sigmoid(x)
 
 
+def log_sigmoid(x):
+    r"""Applies the Log Sigmoid function.
+
+    Applies :math:`\log(\sigma(x)) = -\log(1 + e^{-x})` element wise.
+    """
+    return -softplus(-x)
+
+
 def gelu(x):
     r"""Applies the Gaussian Error Linear Units function.
 
@@ -99,11 +148,80 @@ class ReLU(Module):
     pass
 
 
+class LeakyReLU(Module):
+    r"""Applies the Leaky Rectified Linear Unit.
+
+    Simply ``mx.maximum(negative_slope * x, x)``.
+
+    Args:
+        negative_slope: Controls the angle of the negative slope. Default: 1e-2.
+    """
+
+    def __init__(self, negative_slope=1e-2):
+        super().__init__()
+        self._negative_slope = negative_slope
+
+    def __call__(self, x):
+        return leaky_relu(x, self._negative_slope)
+
+
+class ELU(Module):
+    r"""Applies the Exponential Linear Unit.
+    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
+
+    See :func:`elu`, for the functional equivalent.
+
+    Args:
+        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+    """
+
+    def __init__(self, alpha=1.0):
+        super().__init__()
+        self._alpha = alpha
+
+    def __call__(self, x):
+        return elu(x, self._alpha)
+
+
+@_make_activation_module(relu6)
+class ReLU6(Module):
+    pass
+
+
+@_make_activation_module(softplus)
+class Softplus(Module):
+    pass
+
+
+class CELU(Module):
+    r"""Applies the Continuously Differentiable Exponential Linear Unit.
+    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
+    element wise.
+
+    See :func:`celu`, for the functional equivalent.
+
+    Args:
+        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+    """
+
+    def __init__(self, alpha=1.0):
+        super().__init__()
+        self._alpha = alpha
+
+    def __call__(self, x):
+        return celu(x, self._alpha)
+
+
 @_make_activation_module(silu)
 class SiLU(Module):
     pass
 
 
+@_make_activation_module(log_sigmoid)
+class LogSigmoid(Module):
+    pass
+
+
 class GELU(Module):
     r"""Applies the Gaussian Error Linear Units.
 
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index e434c3ae8..707015814 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -303,6 +303,81 @@ class TestNN(mlx_tests.MLXTestCase):
         eq_tree = tree_map(mx.array_equal, m.parameters(), m_load.parameters())
         self.assertTrue(all(tree_flatten(eq_tree)))
 
+    def test_relu(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.relu(x)
+        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_leaky_relu(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.leaky_relu(x)
+        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.LeakyReLU(negative_slope=0.1)(x)
+        self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_elu(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.elu(x)
+        epsilon = 1e-4
+        expected_y = mx.array([1.0, -0.6321, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.ELU(alpha=1.1)(x)
+        epsilon = 1e-4
+        expected_y = mx.array([1.0, -0.6953, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_relu6(self):
+        x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
+        y = nn.relu6(x)
+        self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
+        self.assertEqual(y.shape, [5])
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_softplus(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.softplus(x)
+        epsilon = 1e-4
+        expected_y = mx.array([1.3133, 0.3133, 0.6931])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_celu(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.celu(x)
+        epsilon = 1e-4
+        expected_y = mx.array([1.0, -0.6321, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+        y = nn.CELU(alpha=1.1)(x)
+        expected_y = mx.array([1.0, -0.6568, 0.0])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
+    def test_log_sigmoid(self):
+        x = mx.array([1.0, -1.0, 0.0])
+        y = nn.log_sigmoid(x)
+        epsilon = 1e-4
+        expected_y = mx.array([-0.3133, -1.3133, -0.6931])
+        self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
+        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.dtype, mx.float32)
+
 
 if __name__ == "__main__":
     unittest.main()
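
For quick sanity checking, a minimal usage sketch of the activations added above (not part of the patch; it assumes the patch is applied and an MLX build with mlx.core is available). The commented values mirror the expectations asserted in python/tests/test_nn.py.

import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, -1.0, 0.0])

# Functional forms added in python/mlx/nn/layers/activations.py
print(nn.leaky_relu(x))             # ~[1.0, -0.01, 0.0]
print(nn.elu(x))                    # ~[1.0, -0.6321, 0.0]
print(nn.relu6(mx.array([7.0])))    # ~[6.0]
print(nn.softplus(x))               # ~[1.3133, 0.3133, 0.6931]
print(nn.celu(x, alpha=1.1))        # ~[1.0, -0.6568, 0.0]
print(nn.log_sigmoid(x))            # ~[-0.3133, -1.3133, -0.6931]

# Module forms compose like any other layer, e.g. inside nn.Sequential
act = nn.Sequential(nn.LeakyReLU(negative_slope=0.1), nn.ReLU6())
print(act(x))                       # ~[1.0, 0.0, 0.0]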