mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Adding Relu2 (#2582)
* in. com.
* upd. ackn.
* update __init__
* nits
* nits + format
* used mx.maximum(x, 0) instead of calling the function and moves relu6 under relu2 to make it nicer
* same with _make_activation_module
* Update python/mlx/nn/layers/activations.py
  upd
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* update funct.rst
* upd. layers.rst
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
@@ -19,7 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
-- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
@@ -27,6 +27,7 @@ simple functions.
 mish
 prelu
 relu
+relu2
 relu6
 selu
 sigmoid
@@ -50,6 +50,7 @@ Layers
 QuantizedLinear
 RMSNorm
 ReLU
+ReLU2
 ReLU6
 RNN
 RoPE
@@ -15,6 +15,7 @@ from mlx.nn.layers.activations import (
     Mish,
     PReLU,
     ReLU,
+    ReLU2,
     ReLU6,
     Sigmoid,
     SiLU,
@@ -40,6 +41,7 @@ from mlx.nn.layers.activations import (
     mish,
     prelu,
     relu,
+    relu2,
     relu6,
     selu,
     sigmoid,
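With the re-exports above, the new names should resolve through both the deep and the top-level import paths. A small sanity check, sketched under the assumption that `mlx.nn` re-exports `mlx.nn.layers` in the usual way (that re-export is not part of this patch):

```python
import mlx.nn as nn
from mlx.nn.layers.activations import ReLU2, relu2

# Both import paths should point at the same objects.
assert nn.relu2 is relu2
assert nn.ReLU2 is ReLU2
```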
@@ -35,6 +35,24 @@ def relu(x):
     return mx.maximum(x, 0)
+
+
+@partial(mx.compile, shapeless=True)
+def relu2(x):
+    r"""Applies the ReLU² activation function.
+
+    Applies :math:`\max(0, x)^2` element wise.
+    """
+    return mx.square(mx.maximum(x, 0))
+
+
+@partial(mx.compile, shapeless=True)
+def relu6(x):
+    r"""Applies the Rectified Linear Unit 6.
+
+    Applies :math:`\min(\max(x, 0), 6)` element wise.
+    """
+    return mx.minimum(mx.maximum(x, 0), 6.0)


 @partial(mx.compile, shapeless=True)
 def leaky_relu(x, negative_slope=0.01):
     r"""Applies the Leaky Rectified Linear Unit.
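For reference, a minimal usage sketch of the new functional form, assuming mlx is installed and that `relu2` is re-exported at the `mlx.nn` top level as the import hunk above indicates:

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 1.5, 3.0])

# relu2 squares the positive part: max(0, x) ** 2
y = nn.relu2(x)
print(y)  # array([0, 0, 0, 2.25, 9], dtype=float32)

# Same result as composing the primitives directly.
assert mx.allclose(y, mx.square(mx.maximum(x, 0)))
```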
@@ -62,15 +80,6 @@ def elu(x, alpha=1.0):
     return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
-
-
-@partial(mx.compile, shapeless=True)
-def relu6(x):
-    r"""Applies the Rectified Linear Unit 6.
-
-    Applies :math:`\min(\max(x, 0), 6)` element wise.
-    """
-    return mx.minimum(mx.maximum(x, 0), 6.0)


 @partial(mx.compile, shapeless=True)
 def softmax(x, axis=-1):
     r"""Applies the Softmax function.
@@ -377,6 +386,22 @@ class ReLU(Module):
     """


+@_make_activation_module(relu2)
+class ReLU2(Module):
+    r"""Applies the ReLU² activation function.
+
+    See :func:`relu2` for the functional equivalent.
+    """
+
+
+@_make_activation_module(relu6)
+class ReLU6(Module):
+    r"""Applies the Rectified Linear Unit 6.
+
+    See :func:`relu6` for the functional equivalent.
+    """
+
+
 class LeakyReLU(Module):
     r"""Applies the Leaky Rectified Linear Unit.
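A short sketch of the new module form in a toy network; the layer sizes here are arbitrary, and `nn.Sequential` / `nn.Linear` are existing mlx.nn layers, not part of this patch:

```python
import mlx.core as mx
import mlx.nn as nn

# ReLU2() is a parameter-free wrapper around the relu2 function above.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU2(),
    nn.Linear(16, 1),
)

x = mx.random.normal((4, 8))
print(model(x).shape)  # (4, 1)
```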
@@ -412,14 +437,6 @@ class ELU(Module):
         return elu(x, self._alpha)


-@_make_activation_module(relu6)
-class ReLU6(Module):
-    r"""Applies the Rectified Linear Unit 6.
-
-    See :func:`relu6` for the functional equivalent.
-    """
-
-
 @_make_activation_module(softmax)
 class Softmax(Module):
     r"""Applies the Softmax function.
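The class hunks above reuse the existing `_make_activation_module` decorator rather than writing a `__call__` by hand. The decorator's actual definition is not part of this diff; the following is only a rough, hypothetical sketch of what such a helper might do:

```python
def _make_activation_module(f):
    """Hypothetical sketch: build a class decorator from a functional activation."""

    def decorator(klass):
        # Reuse the function's docstring and forward calls straight to it;
        # the real helper in activations.py may differ in detail.
        klass.__doc__ = f.__doc__
        klass.__call__ = lambda self, x: f(x)
        return klass

    return decorator
```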