feat: Add ReLUSquared activation function

John Mai 2025-06-15 17:07:22 +08:00
parent a14aaa7c9d
commit 940f64fe6a
2 changed files with 25 additions and 0 deletions


@@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    ReLUSquared,
     Sigmoid,
     SiLU,
     Softmax,
@@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
     prelu,
     relu,
     relu6,
+    relu_squared,
     selu,
     sigmoid,
     silu,


@@ -71,6 +71,17 @@ def relu6(x):
     return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu_squared(x):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+    """
+    return relu(x).square()
+
+
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.
@@ -420,6 +431,18 @@ class ReLU6(Module):
     """
 
 
+@_make_activation_module(relu_squared)
+class ReLUSquared(Module):
+    r"""Applies the Rectified Linear Unit squared.
+
+    Applies :math:`\max(x, 0)^2` element wise.
+
+    Reference: https://arxiv.org/abs/2109.08668v2
+
+    See :func:`relu_squared` for the functional equivalent.
+    """
+
+
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.