Adding Relu2 (#2582)

* in. com.

* upd. ackn.

* update __init__

* nits

* nits + format

* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer

* same with _make_activation_module

* Update python/mlx/nn/layers/activations.py

upd

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update funct.rst

* upd. layers.rst

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Gökdeniz Gülmez
2025-09-10 16:24:30 +02:00
committed by GitHub
parent 52b8384d10
commit db5443e831
5 changed files with 39 additions and 18 deletions

View File

@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -27,6 +27,7 @@ simple functions.
mish
prelu
relu
relu2
relu6
selu
sigmoid

View File

@@ -50,6 +50,7 @@ Layers
QuantizedLinear
RMSNorm
ReLU
ReLU2
ReLU6
RNN
RoPE

View File

@@ -15,6 +15,7 @@ from mlx.nn.layers.activations import (
Mish,
PReLU,
ReLU,
ReLU2,
ReLU6,
Sigmoid,
SiLU,
@@ -40,6 +41,7 @@ from mlx.nn.layers.activations import (
mish,
prelu,
relu,
relu2,
relu6,
selu,
sigmoid,

View File

@@ -35,6 +35,24 @@ def relu(x):
return mx.maximum(x, 0)
@partial(mx.compile, shapeless=True)
def relu2(x):
r"""Applies the ReLU² activation function.
Applies :math:`\max(0, x)^2` element wise.
"""
return mx.square(mx.maximum(x, 0))
@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.
Applies :math:`\min(\max(x, 0), 6)` element wise.
"""
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def leaky_relu(x, negative_slope=0.01):
r"""Applies the Leaky Rectified Linear Unit.
@@ -62,15 +80,6 @@ def elu(x, alpha=1.0):
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
@partial(mx.compile, shapeless=True)
def relu6(x):
r"""Applies the Rectified Linear Unit 6.
Applies :math:`\min(\max(x, 0), 6)` element wise.
"""
return mx.minimum(mx.maximum(x, 0), 6.0)
@partial(mx.compile, shapeless=True)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
@@ -377,6 +386,22 @@ class ReLU(Module):
"""
@_make_activation_module(relu2)
class ReLU2(Module):
r"""Applies the ReLU² activation function.
See :func:`relu2` for the functional equivalent.
"""
@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6` for the functional equivalent.
"""
class LeakyReLU(Module):
r"""Applies the Leaky Rectified Linear Unit.
@@ -412,14 +437,6 @@ class ELU(Module):
return elu(x, self._alpha)
@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6` for the functional equivalent.
"""
@_make_activation_module(softmax)
class Softmax(Module):
r"""Applies the Softmax function.