diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index 0821ccae6..658cc6f95 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -224,6 +224,13 @@ def relu6(x):
     mx.eval(y)
 
 
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = nn.relu_squared(y)
+    mx.eval(y)
+
+
 def softplus(x):
     y = x
     for i in range(100):
@@ -458,6 +465,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "celu":
         print(bench(celu, x))
 
diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py
index a2157707b..6c5c518f6 100644
--- a/benchmarks/python/comparative/bench_torch.py
+++ b/benchmarks/python/comparative/bench_torch.py
@@ -157,6 +157,15 @@ def relu6(x):
     sync_if_needed(x)
 
 
+@torch.no_grad()
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.relu(y)
+        y = torch.square(y)
+    sync_if_needed(x)
+
+
 @torch.no_grad()
 def softplus(x):
     y = x
@@ -407,6 +416,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "softplus":
         print(bench(softplus, x))
 
diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py
index 68b4a5bd3..e3270e903 100644
--- a/benchmarks/python/comparative/compare.py
+++ b/benchmarks/python/comparative/compare.py
@@ -207,6 +207,8 @@ if __name__ == "__main__":
     compare_filtered("elu --size 32x16x1024 --cpu")
     compare_filtered("relu6 --size 32x16x1024")
     compare_filtered("relu6 --size 32x16x1024 --cpu")
+    compare_filtered("relu_squared --size 32x16x1024")
+    compare_filtered("relu_squared --size 32x16x1024 --cpu")
     compare_filtered("softplus --size 32x16x1024")
     compare_filtered("softplus --size 32x16x1024 --cpu")
     compare_filtered("celu --size 32x16x1024")
diff --git a/docs/src/python/nn/functions.rst b/docs/src/python/nn/functions.rst
index 9b6cd9f62..325a24ef2 100644
--- a/docs/src/python/nn/functions.rst
+++ b/docs/src/python/nn/functions.rst
@@ -28,6 +28,7 @@ simple functions.
    prelu
    relu
    relu6
+   relu_squared
    selu
    sigmoid
    silu
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index 4eb14b088..218f7f457 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -51,6 +51,7 @@ Layers
    RMSNorm
    ReLU
    ReLU6
+   ReLUSquared
    RNN
    RoPE
    SELU
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 26f77917f..b434f00de 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    ReLUSquared,
     Sigmoid,
     SiLU,
     Softmax,
@@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
     prelu,
     relu,
     relu6,
+    relu_squared,
     selu,
     sigmoid,
     silu,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 8eafd75d3..4f2f944b6 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -71,6 +71,17 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu_squared(x):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+    """
+    return relu(x).square()
+
+
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.
@@ -420,6 +431,18 @@ class ReLU6(Module):
     """
 
 
+@_make_activation_module(relu_squared)
+class ReLUSquared(Module):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+
+    See :func:`relu_squared` for the functional equivalent.
+    """
+
+
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 7753224b3..a65887ab5 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -860,6 +860,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_relu_squared(self):
+        x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
+        y = nn.relu_squared(x)
+        self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_leaky_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.leaky_relu(x)
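Usage sketch (not part of the patch): assuming the diff above is applied to an MLX checkout, the new activation is exposed both as the functional mlx.nn.relu_squared and as the mlx.nn.ReLUSquared module, mirroring the existing ReLU variants. A minimal example:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 1.5])

# Functional form: max(x, 0) ** 2, applied element wise.
y = nn.relu_squared(x)  # expected: [0.0, 0.0, 0.0, 2.25]

# Module form, usable inside a model definition (e.g. an nn.Sequential stack).
act = nn.ReLUSquared()
print(y, act(x))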