Adding Relu2 (#2582)

* in. com.

* upd. ackn.

* update __init__

* nits

* nits + format

* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer

* same with _make_activation_module

* Update python/mlx/nn/layers/activations.py

upd

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update funct.rst

* upd. layers.rst

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Gökdeniz Gülmez
2025-09-10 16:24:30 +02:00
committed by GitHub
parent 52b8384d10
commit db5443e831
5 changed files with 39 additions and 18 deletions

View File

@@ -15,6 +15,7 @@ from mlx.nn.layers.activations import (
Mish,
PReLU,
ReLU,
ReLU2,
ReLU6,
Sigmoid,
SiLU,
@@ -40,6 +41,7 @@ from mlx.nn.layers.activations import (
mish,
prelu,
relu,
relu2,
relu6,
selu,
sigmoid,

View File

@@ -35,6 +35,24 @@ def relu(x):
return mx.maximum(x, 0)
@partial(mx.compile, shapeless=True)
def relu2(x):
r"""Applies the ReLU² activation function.
Applies :math:`\max(0, x)^2` element wise.
"""
return mx.square(mx.maximum(x, 0))
@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.
Applies :math:`\min(\max(x, 0), 6)` element wise.
"""
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def leaky_relu(x, negative_slope=0.01):
r"""Applies the Leaky Rectified Linear Unit.
@@ -62,15 +80,6 @@ def elu(x, alpha=1.0):
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.
Applies :math:`\min(\max(x, 0), 6)` element wise.
"""
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
@@ -377,6 +386,22 @@ class ReLU(Module):
"""
@_make_activation_module(relu2)
class ReLU2(Module):
r"""Applies the ReLU² activation function.
See :func:`relu2` for the functional equivalent.
"""
@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6` for the functional equivalent.
"""
class LeakyReLU(Module):
r"""Applies the Leaky Rectified Linear Unit.
@@ -412,14 +437,6 @@ class ELU(Module):
return elu(x, self._alpha)
@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6` for the functional equivalent.
"""
@_make_activation_module(softmax)
class Softmax(Module):
r"""Applies the Softmax function.