From db5443e83107adcf5bbf44d15cf3a0f1b76fb8fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com>
Date: Wed, 10 Sep 2025 16:24:30 +0200
Subject: [PATCH] Adding Relu2 (#2582)

* in. com.

* upd. ackn.

* update __init__

* nits

* nits + format

* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer

* same with _make_activation_module

* Update python/mlx/nn/layers/activations.py

upd

Co-authored-by: Awni Hannun

* update funct.rst

* upd. layers.rst

---------

Co-authored-by: Awni Hannun
---
 ACKNOWLEDGMENTS.md                  |  2 +-
 docs/src/python/nn/functions.rst    |  1 +
 docs/src/python/nn/layers.rst       |  1 +
 python/mlx/nn/layers/__init__.py    |  2 ++
 python/mlx/nn/layers/activations.py | 51 +++++++++++++++++++----------
 5 files changed, 39 insertions(+), 18 deletions(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 6e479e5a6..186908f09 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
diff --git a/docs/src/python/nn/functions.rst b/docs/src/python/nn/functions.rst
index 9b6cd9f62..d3c533d0c 100644
--- a/docs/src/python/nn/functions.rst
+++ b/docs/src/python/nn/functions.rst
@@ -27,6 +27,7 @@ simple functions.
    mish
    prelu
    relu
+   relu2
    relu6
    selu
    sigmoid
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index 4eb14b088..146948147 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -50,6 +50,7 @@ Layers
    QuantizedLinear
    RMSNorm
    ReLU
+   ReLU2
    ReLU6
    RNN
    RoPE
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 26f77917f..ea2d3029d 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -15,6 +15,7 @@ from mlx.nn.layers.activations import (
     Mish,
     PReLU,
     ReLU,
+    ReLU2,
     ReLU6,
     Sigmoid,
     SiLU,
@@ -40,6 +41,7 @@ from mlx.nn.layers.activations import (
     mish,
     prelu,
     relu,
+    relu2,
     relu6,
     selu,
     sigmoid,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 21994c0e6..360bb113d 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -35,6 +35,24 @@ def relu(x):
     return mx.maximum(x, 0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu2(x):
+    r"""Applies the ReLU² activation function.
+
+    Applies :math:`\max(0, x)^2` element wise.
+    """
+    return mx.square(mx.maximum(x, 0))
+
+
+@partial(mx.compile, shapeless=True)
+def relu6(x):
+    r"""Applies the Rectified Linear Unit 6.
+
+    Applies :math:`\min(\max(x, 0), 6)` element wise.
+    """
+    return mx.minimum(mx.maximum(x, 0), 6.0)
+
+
 @partial(mx.compile, shapeless=True)
 def leaky_relu(x, negative_slope=0.01):
     r"""Applies the Leaky Rectified Linear Unit.
@@ -62,15 +80,6 @@ def elu(x, alpha=1.0):
     return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
 
 
-@partial(mx.compile, shapeless=True)
-def relu6(x):
-    r"""Applies the Rectified Linear Unit 6.
-
-    Applies :math:`\min(\max(x, 0), 6)` element wise.
-    """
-    return mx.minimum(mx.maximum(x, 0), 6.0)
-
-
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.
@@ -377,6 +386,22 @@ class ReLU(Module):
     """
 
 
+@_make_activation_module(relu2)
+class ReLU2(Module):
+    r"""Applies the ReLU² activation function.
+
+    See :func:`relu2` for the functional equivalent.
+    """
+
+
+@_make_activation_module(relu6)
+class ReLU6(Module):
+    r"""Applies the Rectified Linear Unit 6.
+
+    See :func:`relu6` for the functional equivalent.
+    """
+
+
 class LeakyReLU(Module):
     r"""Applies the Leaky Rectified Linear Unit.
 
@@ -412,14 +437,6 @@ class ELU(Module):
         return elu(x, self._alpha)
 
 
-@_make_activation_module(relu6)
-class ReLU6(Module):
-    r"""Applies the Rectified Linear Unit 6.
-
-    See :func:`relu6` for the functional equivalent.
-    """
-
-
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.