From 940f64fe6a8deda3f44141c19c450014d57ce44e Mon Sep 17 00:00:00 2001
From: John Mai
Date: Sun, 15 Jun 2025 17:07:22 +0800
Subject: [PATCH] feat: Add ReLUSquared activation function

---
 python/mlx/nn/layers/__init__.py    |  2 ++
 python/mlx/nn/layers/activations.py | 23 +++++++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 26f77917fc..b434f00dec 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    ReLUSquared,
     Sigmoid,
     SiLU,
     Softmax,
@@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
     prelu,
     relu,
     relu6,
+    relu_squared,
     selu,
     sigmoid,
     silu,
diff --git a/python/mlx/nn/layers/activations.py b/python/mlx/nn/layers/activations.py
index 8eafd75d3a..4f2f944b6c 100644
--- a/python/mlx/nn/layers/activations.py
+++ b/python/mlx/nn/layers/activations.py
@@ -71,6 +71,17 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu_squared(x):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+    """
+    return relu(x).square()
+
+
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.
@@ -420,6 +431,18 @@ class ReLU6(Module):
     """
 
 
+@_make_activation_module(relu_squared)
+class ReLUSquared(Module):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+
+    See :func:`relu_squared` for the functional equivalent.
+    """
+
+
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.
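
For context, a minimal usage sketch of the activation this patch adds. It assumes the patch is applied and that the new symbols are re-exported at the mlx.nn level like the other activations (relu, relu6, etc.); the input values are arbitrary illustration and are not part of the patch.

# Minimal usage sketch; assumes an MLX build that includes this patch.
import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 1.5, 3.0])  # arbitrary example inputs

# Functional form: computes max(x, 0) ** 2 element-wise.
y_fn = nn.relu_squared(x)   # expected: [0, 0, 0, 2.25, 9]

# Module form, usable inside nn.Sequential like the other activation modules.
act = nn.ReLUSquared()
y_mod = act(x)              # same result as the functional form

Since the module is generated with _make_activation_module(relu_squared), both forms dispatch to the same compiled function, so there is no behavioral difference between them.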