Adding Relu2 (#2582)

* initial commit

* update acknowledgements

* update __init__

* nits

* nits + format

* use mx.maximum(x, 0) directly instead of calling relu, and move relu6 below relu2 so the definitions read in order

* do the same reordering for the _make_activation_module classes

* Update python/mlx/nn/layers/activations.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* update functions.rst

* update layers.rst

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Author: Gökdeniz Gülmez
Date: 2025-09-10 16:24:30 +02:00
Committed by: GitHub
Parent: 52b8384d10
Commit: db5443e831
5 changed files with 39 additions and 18 deletions


@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />


@@ -27,6 +27,7 @@ simple functions.
 mish
 prelu
 relu
+relu2
 relu6
 selu
 sigmoid


@@ -50,6 +50,7 @@ Layers
 QuantizedLinear
 RMSNorm
 ReLU
+ReLU2
 ReLU6
 RNN
 RoPE


@@ -15,6 +15,7 @@ from mlx.nn.layers.activations import (
     Mish,
     PReLU,
     ReLU,
+    ReLU2,
     ReLU6,
     Sigmoid,
     SiLU,
@@ -40,6 +41,7 @@ from mlx.nn.layers.activations import (
     mish,
     prelu,
     relu,
+    relu2,
     relu6,
     selu,
     sigmoid,
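
For context (not part of the commit), a minimal usage sketch of what these exports make available from mlx.nn once the change is applied; the input values below are arbitrary:

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 1.5, 3.0])

# Functional form added by this commit: max(0, x)^2, element wise.
y_fn = nn.relu2(x)

# Module form, usable inside nn.Sequential or other layer stacks.
y_mod = nn.ReLU2()(x)

print(y_fn)   # array([0, 0, 0, 2.25, 9], dtype=float32)
print(y_mod)  # same values as the functional call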


@@ -35,6 +35,24 @@ def relu(x):
     return mx.maximum(x, 0)
 
 
+@partial(mx.compile, shapeless=True)
+def relu2(x):
+    r"""Applies the ReLU² activation function.
+
+    Applies :math:`\max(0, x)^2` element wise.
+    """
+    return mx.square(mx.maximum(x, 0))
+
+
+@partial(mx.compile, shapeless=True)
+def relu6(x):
+    r"""Applies the Rectified Linear Unit 6.
+
+    Applies :math:`\min(\max(x, 0), 6)` element wise.
+    """
+    return mx.minimum(mx.maximum(x, 0), 6.0)
+
+
 @partial(mx.compile, shapeless=True)
 def leaky_relu(x, negative_slope=0.01):
     r"""Applies the Leaky Rectified Linear Unit.
@@ -62,15 +80,6 @@ def elu(x, alpha=1.0):
     return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
-
-
-@partial(mx.compile, shapeless=True)
-def relu6(x):
-    r"""Applies the Rectified Linear Unit 6.
-
-    Applies :math:`\min(\max(x, 0), 6)` element wise.
-    """
-    return mx.minimum(mx.maximum(x, 0), 6.0)
 
 
 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.
@@ -377,6 +386,22 @@ class ReLU(Module):
     See :func:`relu` for the functional equivalent.
     """
+
+
+@_make_activation_module(relu2)
+class ReLU2(Module):
+    r"""Applies the ReLU² activation function.
+
+    See :func:`relu2` for the functional equivalent.
+    """
+
+
+@_make_activation_module(relu6)
+class ReLU6(Module):
+    r"""Applies the Rectified Linear Unit 6.
+
+    See :func:`relu6` for the functional equivalent.
+    """
 
 
 class LeakyReLU(Module):
     r"""Applies the Leaky Rectified Linear Unit.
@@ -412,14 +437,6 @@ class ELU(Module):
         return elu(x, self._alpha)
-
-
-@_make_activation_module(relu6)
-class ReLU6(Module):
-    r"""Applies the Rectified Linear Unit 6.
-
-    See :func:`relu6` for the functional equivalent.
-    """
 
 
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.