From 940f64fe6a8deda3f44141c19c450014d57ce44e Mon Sep 17 00:00:00 2001
From: John Mai
Date: Sun, 15 Jun 2025 17:07:22 +0800
Subject: [PATCH 1/5] feat: Add ReLUSquared activation function

---
 python/mlx/nn/layers/__init__.py    |  2 ++
 python/mlx/nn/layers/activations.py | 23 +++++++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 26f77917f..b434f00de 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    ReLUSquared,
     Sigmoid,
     SiLU,
     Softmax,
@@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
     prelu,
     relu,
     relu6,
+    relu_squared,
     selu,
     sigmoid,
     silu,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 8eafd75d3..4f2f944b6 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -71,6 +71,17 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu_squared(x):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+    """
+    return relu(x).square()
+
+
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.
@@ -420,6 +431,18 @@ class ReLU6(Module):
     """
 
 
+@_make_activation_module(relu_squared)
+class ReLUSquared(Module):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+
+    See :func:`relu_squared` for the functional equivalent.
+    """
+
+
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.
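For reference, a minimal usage sketch of the activation added above. It is not part of the patch series and assumes an MLX build with PATCH 1/5 applied; both forms compute max(x, 0)^2 element wise:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-2.0, -0.5, 0.0, 1.5, 3.0])

    # Functional form: max(x, 0) ** 2, element wise.
    y = nn.relu_squared(x)  # -> [0, 0, 0, 2.25, 9]

    # Module form of the same activation, for composing into models.
    layer = nn.ReLUSquared()
    z = layer(x)  # same values as the functional form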
From cbd353bf73794c27cbcde672a7a32abbeaf80238 Mon Sep 17 00:00:00 2001
From: John Mai
Date: Sun, 15 Jun 2025 17:07:33 +0800
Subject: [PATCH 2/5] test: Add unit test for ReLUSquared activation function

---
 python/tests/test_nn.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 13e31ad96..ca17b20be 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -855,6 +855,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
+    def test_relu_squared(self):
+        x = mx.array([-1.0, 0.0, 1.0, 2.0, 3.0])
+        y = nn.relu_squared(x)
+        self.assertTrue(mx.array_equal(y, mx.array([0.0, 0.0, 1.0, 4.0, 9.0])))
+        self.assertEqual(y.shape, (5,))
+        self.assertEqual(y.dtype, mx.float32)
+
     def test_leaky_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.leaky_relu(x)

From fe0672a9d24d81dcf85a3ce48ce60157d4a86fdf Mon Sep 17 00:00:00 2001
From: John Mai
Date: Sun, 15 Jun 2025 17:33:58 +0800
Subject: [PATCH 3/5] docs: Update documentation to include ReLUSquared activation function

---
 docs/src/python/nn/functions.rst | 1 +
 docs/src/python/nn/layers.rst    | 1 +
 2 files changed, 2 insertions(+)

diff --git a/docs/src/python/nn/functions.rst b/docs/src/python/nn/functions.rst
index 9b6cd9f62..325a24ef2 100644
--- a/docs/src/python/nn/functions.rst
+++ b/docs/src/python/nn/functions.rst
@@ -28,6 +28,7 @@ simple functions.
    prelu
    relu
    relu6
+   relu_squared
    selu
    sigmoid
    silu
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index 4eb14b088..218f7f457 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -51,6 +51,7 @@ Layers
    RMSNorm
    ReLU
    ReLU6
+   ReLUSquared
    RNN
    RoPE
    SELU

From 989e8bab66c07ae3d5481fe282b4bcaa256207b8 Mon Sep 17 00:00:00 2001
From: John Mai
Date: Sun, 15 Jun 2025 17:34:10 +0800
Subject: [PATCH 4/5] feat: Add benchmarking for ReLUSquared activation function

---
 benchmarks/python/comparative/bench_mlx.py   |  8 ++++++++
 benchmarks/python/comparative/bench_torch.py | 10 ++++++++++
 benchmarks/python/comparative/compare.py     |  2 ++
 3 files changed, 20 insertions(+)

diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index 0821ccae6..4fe29a991 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -223,6 +223,11 @@ def relu6(x):
         y = nn.relu6(y)
     mx.eval(y)
 
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = nn.relu_squared(y)
+    mx.eval(y)
 
 def softplus(x):
     y = x
@@ -458,6 +463,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "celu":
         print(bench(celu, x))
 
diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py
index a2157707b..c9a65a819 100644
--- a/benchmarks/python/comparative/bench_torch.py
+++ b/benchmarks/python/comparative/bench_torch.py
@@ -156,6 +156,13 @@ def relu6(x):
         y = torch.nn.functional.relu6(y)
     sync_if_needed(x)
 
+@torch.no_grad()
+def relu_squared(x):
+    y = x
+    for i in range(100):
+        y = torch.nn.functional.relu(y)
+        y = torch.square(y)
+    sync_if_needed(x)
 
 @torch.no_grad()
 def softplus(x):
@@ -407,6 +414,9 @@ if __name__ == "__main__":
     elif args.benchmark == "relu6":
         print(bench(relu6, x))
 
+    elif args.benchmark == "relu_squared":
+        print(bench(relu_squared, x))
+
     elif args.benchmark == "softplus":
         print(bench(softplus, x))
 
diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py
index 68b4a5bd3..e3270e903 100644
--- a/benchmarks/python/comparative/compare.py
+++ b/benchmarks/python/comparative/compare.py
@@ -207,6 +207,8 @@ if __name__ == "__main__":
     compare_filtered("elu --size 32x16x1024 --cpu")
     compare_filtered("relu6 --size 32x16x1024")
     compare_filtered("relu6 --size 32x16x1024 --cpu")
+    compare_filtered("relu_squared --size 32x16x1024")
+    compare_filtered("relu_squared --size 32x16x1024 --cpu")
     compare_filtered("softplus --size 32x16x1024")
     compare_filtered("softplus --size 32x16x1024 --cpu")
     compare_filtered("celu --size 32x16x1024")
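As a sanity check, the MLX function from PATCH 1/5 can be compared numerically against the PyTorch baseline used in bench_torch.py above (relu followed by square). This is a standalone sketch, not part of the patch series; it assumes both mlx and torch are installed and the patch is applied:

    import numpy as np
    import torch

    import mlx.core as mx
    import mlx.nn as nn

    # Same tensor size as the compare.py benchmark entries (32x16x1024).
    x_np = np.random.randn(32, 16, 1024).astype(np.float32)

    # MLX result, converted back to NumPy for comparison.
    y_mlx = np.array(nn.relu_squared(mx.array(x_np)))
    # PyTorch reference: relu then square, as in bench_torch.py.
    y_torch = torch.square(torch.nn.functional.relu(torch.from_numpy(x_np))).numpy()

    assert np.allclose(y_mlx, y_torch)

The actual timing comparison is driven by compare.py, which runs relu_squared at size 32x16x1024 on both the default device and the CPU.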
From b3c1aaafd25e5b270127ab1940648c8d92dc4dbb Mon Sep 17 00:00:00 2001
From: John Mai
Date: Sun, 15 Jun 2025 17:35:33 +0800
Subject: [PATCH 5/5] update: format code

---
 benchmarks/python/comparative/bench_mlx.py   | 2 ++
 benchmarks/python/comparative/bench_torch.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py
index 4fe29a991..658cc6f95 100644
--- a/benchmarks/python/comparative/bench_mlx.py
+++ b/benchmarks/python/comparative/bench_mlx.py
@@ -223,12 +223,14 @@ def relu6(x):
         y = nn.relu6(y)
     mx.eval(y)
 
+
 def relu_squared(x):
     y = x
     for i in range(100):
         y = nn.relu_squared(y)
     mx.eval(y)
 
+
 def softplus(x):
     y = x
     for i in range(100):
diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py
index c9a65a819..6c5c518f6 100644
--- a/benchmarks/python/comparative/bench_torch.py
+++ b/benchmarks/python/comparative/bench_torch.py
@@ -156,6 +156,7 @@ def relu6(x):
         y = torch.nn.functional.relu6(y)
     sync_if_needed(x)
 
+
 @torch.no_grad()
 def relu_squared(x):
     y = x
@@ -164,6 +165,7 @@ def relu_squared(x):
         y = torch.square(y)
     sync_if_needed(x)
 
+
 @torch.no_grad()
 def softplus(x):
     y = x
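For a quick local timing outside the comparative harness, the loop that bench_mlx.py gains in PATCH 4/5 can also be reproduced standalone. This is only a sketch assuming the patch is applied; the comparative numbers still come from compare.py:

    import time

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((32, 16, 1024))
    mx.eval(x)

    tic = time.perf_counter()
    y = x
    for _ in range(100):
        y = nn.relu_squared(y)
    mx.eval(y)  # force evaluation of the lazy graph
    toc = time.perf_counter()
    print(f"100 x relu_squared: {(toc - tic) * 1e3:.3f} ms")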