Mirror of https://github.com/ml-explore/mlx.git (synced 2025-09-18 10:26:56 +08:00)
Add softmin, hardshrink, hardtanh (#1180)
Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
@@ -6,7 +6,9 @@ from mlx.nn.layers.activations import (
     GELU,
     GLU,
     SELU,
+    HardShrink,
     Hardswish,
+    HardTanh,
     LeakyReLU,
     LogSigmoid,
     LogSoftmax,
@@ -17,6 +19,7 @@ from mlx.nn.layers.activations import (
     Sigmoid,
     SiLU,
     Softmax,
+    Softmin,
     Softplus,
     Softshrink,
     Softsign,
@@ -28,6 +31,8 @@ from mlx.nn.layers.activations import (
     gelu_approx,
     gelu_fast_approx,
     glu,
+    hard_shrink,
+    hard_tanh,
     hardswish,
     leaky_relu,
     log_sigmoid,
@@ -40,6 +45,7 @@ from mlx.nn.layers.activations import (
     sigmoid,
     silu,
     softmax,
+    softmin,
     softplus,
     softshrink,
     softsign,
@@ -286,6 +286,38 @@ def hardswish(x):
     return x * mx.minimum(max_x_3, 6) / 6
 
 
+@partial(mx.compile, shapeless=True)
+def hard_tanh(x, min_val=-1.0, max_val=1.0):
+    r"""Applies the HardTanh function.
+
+    Applies :math:`\max(\min(x, \text{max\_val}), \text{min\_val})` element-wise.
+    """
+    return mx.minimum(mx.maximum(x, min_val), max_val)
+
+
+@partial(mx.compile, shapeless=True)
+def hard_shrink(x, lambd=0.5):
+    r"""Applies the HardShrink activation function.
+
+    .. math::
+        \text{hardshrink}(x) = \begin{cases}
+        x & \text{if } x > \lambda \\
+        x & \text{if } x < -\lambda \\
+        0 & \text{otherwise}
+        \end{cases}
+    """
+    return mx.where(mx.abs(x) > lambd, x, 0)
+
+
+@partial(mx.compile, shapeless=True)
+def softmin(x, axis=-1):
+    r"""Applies the Softmin function.
+
+    Applies :math:`\frac{e^{-x_i}}{\sum_j e^{-x_j}}` element-wise.
+    """
+    return mx.softmax(-x, axis=axis)
+
+
 def tanh(x):
     """Applies the hyperbolic tangent function.
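
Similarly, a hedged sketch of hard_shrink, which zeroes out entries whose magnitude does not exceed lambd (same export assumption as above).

import mlx.core as mx
import mlx.nn as nn

x = mx.array([-0.8, -0.3, 0.0, 0.3, 0.8])
# Default lambd=0.5 keeps only |x| > 0.5: [-0.8, 0.0, 0.0, 0.0, 0.8]
print(nn.hard_shrink(x))
# A smaller threshold also keeps the +/-0.3 entries
print(nn.hard_shrink(x, lambd=0.25))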
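
And softmin, which by definition is softmax applied to the negated input, so the two calls below should agree.

import mlx.core as mx
import mlx.nn as nn

x = mx.array([1.0, 2.0, 3.0])
# Largest weight goes to the smallest entry (approx. [0.665, 0.245, 0.090])
print(nn.softmin(x))
# Equivalent by definition
print(mx.softmax(-x, axis=-1))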
@@ -579,3 +611,30 @@ class SELU(Module):
 
     See :func:`selu` for the functional equivalent.
     """
+
+
+@_make_activation_module(hard_tanh)
+class HardTanh(Module):
+    r"""Applies the HardTanh function.
+
+    See :func:`hard_tanh` for the functional equivalent.
+    """
+
+
+@_make_activation_module(hard_shrink)
+class HardShrink(Module):
+    r"""Applies the HardShrink function.
+
+    See :func:`hard_shrink` for the functional equivalent.
+
+    Args:
+        lambd: the :math:`\lambda` value for Hardshrink. Default: ``0.5``
+    """
+
+
+@_make_activation_module(softmin)
+class Softmin(Module):
+    r"""Applies the Softmin function.
+
+    See :func:`softmin` for the functional equivalent.
+    """