mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-08 18:16:41 +08:00
feat: Add ReLUSquared activation function
This commit is contained in:
parent
a14aaa7c9d
commit
940f64fe6a
@ -16,6 +16,7 @@ from mlx.nn.layers.activations import (
|
|||||||
PReLU,
|
PReLU,
|
||||||
ReLU,
|
ReLU,
|
||||||
ReLU6,
|
ReLU6,
|
||||||
|
ReLUSquared,
|
||||||
Sigmoid,
|
Sigmoid,
|
||||||
SiLU,
|
SiLU,
|
||||||
Softmax,
|
Softmax,
|
||||||
@ -41,6 +42,7 @@ from mlx.nn.layers.activations import (
|
|||||||
prelu,
|
prelu,
|
||||||
relu,
|
relu,
|
||||||
relu6,
|
relu6,
|
||||||
|
relu_squared,
|
||||||
selu,
|
selu,
|
||||||
sigmoid,
|
sigmoid,
|
||||||
silu,
|
silu,
|
||||||
|
@ -71,6 +71,17 @@ def relu6(x):
|
|||||||
return mx.minimum(mx.maximum(x, 0), 6.0)
|
return mx.minimum(mx.maximum(x, 0), 6.0)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
|
def relu_squared(x):
|
||||||
|
r"""Applies the Rectified Linear Unit squared.
|
||||||
|
|
||||||
|
Applies :math:`\max(x, 0)^2` element wise.
|
||||||
|
|
||||||
|
Reference: https://arxiv.org/abs/2109.08668v2
|
||||||
|
"""
|
||||||
|
return relu(x).square()
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def softmax(x, axis=-1):
|
def softmax(x, axis=-1):
|
||||||
r"""Applies the Softmax function.
|
r"""Applies the Softmax function.
|
||||||
@ -420,6 +431,18 @@ class ReLU6(Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@_make_activation_module(relu_squared)
|
||||||
|
class ReLUSquared(Module):
|
||||||
|
r"""Applies the Rectified Linear Unit squared.
|
||||||
|
|
||||||
|
Applies :math:`\max(x, 0)^2` element wise.
|
||||||
|
|
||||||
|
Reference: https://arxiv.org/abs/2109.08668v2
|
||||||
|
|
||||||
|
See :func:`relu_squared` for the functional equivalent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@_make_activation_module(softmax)
|
@_make_activation_module(softmax)
|
||||||
class Softmax(Module):
|
class Softmax(Module):
|
||||||
r"""Applies the Softmax function.
|
r"""Applies the Softmax function.
|
||||||
|
Loading…
Reference in New Issue
Block a user