mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Adding Relu2 (#2582)
* in. com. * upd. ackn. * update __init__ * nits * nits + format * used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer * same with _make_activation_module * Update python/mlx/nn/layers/activations.py upd Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * update funct.rst * upd. layers.rst --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
@@ -27,6 +27,7 @@ simple functions.
|
||||
mish
|
||||
prelu
|
||||
relu
|
||||
relu2
|
||||
relu6
|
||||
selu
|
||||
sigmoid
|
||||
|
@@ -50,6 +50,7 @@ Layers
|
||||
QuantizedLinear
|
||||
RMSNorm
|
||||
ReLU
|
||||
ReLU2
|
||||
ReLU6
|
||||
RNN
|
||||
RoPE
|
||||
|
@@ -15,6 +15,7 @@ from mlx.nn.layers.activations import (
|
||||
Mish,
|
||||
PReLU,
|
||||
ReLU,
|
||||
ReLU2,
|
||||
ReLU6,
|
||||
Sigmoid,
|
||||
SiLU,
|
||||
@@ -40,6 +41,7 @@ from mlx.nn.layers.activations import (
|
||||
mish,
|
||||
prelu,
|
||||
relu,
|
||||
relu2,
|
||||
relu6,
|
||||
selu,
|
||||
sigmoid,
|
||||
|
@@ -35,6 +35,24 @@ def relu(x):
|
||||
return mx.maximum(x, 0)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def relu2(x):
|
||||
r"""Applies the ReLU² activation function.
|
||||
|
||||
Applies :math:`\max(0, x)^2` element wise.
|
||||
"""
|
||||
return mx.square(mx.maximum(x, 0))
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def relu6(x):
|
||||
r"""Applies the Rectified Linear Unit 6.
|
||||
|
||||
Applies :math:`\min(\max(x, 0), 6)` element wise.
|
||||
"""
|
||||
return mx.minimum(mx.maximum(x, 0), 6.0)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def leaky_relu(x, negative_slope=0.01):
|
||||
r"""Applies the Leaky Rectified Linear Unit.
|
||||
@@ -62,15 +80,6 @@ def elu(x, alpha=1.0):
|
||||
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def relu6(x):
|
||||
r"""Applies the Rectified Linear Unit 6.
|
||||
|
||||
Applies :math:`\min(\max(x, 0), 6)` element wise.
|
||||
"""
|
||||
return mx.minimum(mx.maximum(x, 0), 6.0)
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def softmax(x, axis=-1):
|
||||
r"""Applies the Softmax function.
|
||||
@@ -377,6 +386,22 @@ class ReLU(Module):
|
||||
"""
|
||||
|
||||
|
||||
@_make_activation_module(relu2)
|
||||
class ReLU2(Module):
|
||||
r"""Applies the ReLU² activation function.
|
||||
|
||||
See :func:`relu2` for the functional equivalent.
|
||||
"""
|
||||
|
||||
|
||||
@_make_activation_module(relu6)
|
||||
class ReLU6(Module):
|
||||
r"""Applies the Rectified Linear Unit 6.
|
||||
|
||||
See :func:`relu6` for the functional equivalent.
|
||||
"""
|
||||
|
||||
|
||||
class LeakyReLU(Module):
|
||||
r"""Applies the Leaky Rectified Linear Unit.
|
||||
|
||||
@@ -412,14 +437,6 @@ class ELU(Module):
|
||||
return elu(x, self._alpha)
|
||||
|
||||
|
||||
@_make_activation_module(relu6)
|
||||
class ReLU6(Module):
|
||||
r"""Applies the Rectified Linear Unit 6.
|
||||
|
||||
See :func:`relu6` for the functional equivalent.
|
||||
"""
|
||||
|
||||
|
||||
@_make_activation_module(softmax)
|
||||
class Softmax(Module):
|
||||
r"""Applies the Softmax function.
|
||||
|
Reference in New Issue
Block a user