Add softmin, hardshrink, hardtanh (#1180)

---------

Co-authored-by: Nikhil Mehta <nikmehta@tesla.com>
This commit is contained in:
Nikhil Mehta
2024-06-04 15:48:18 -07:00
committed by GitHub
parent 83b11bc58d
commit 0b7d71fd2f
14 changed files with 110 additions and 20 deletions

View File

@@ -6,7 +6,9 @@ from mlx.nn.layers.activations import (
GELU,
GLU,
SELU,
HardShrink,
Hardswish,
HardTanh,
LeakyReLU,
LogSigmoid,
LogSoftmax,
@@ -17,6 +19,7 @@ from mlx.nn.layers.activations import (
Sigmoid,
SiLU,
Softmax,
Softmin,
Softplus,
Softshrink,
Softsign,
@@ -28,6 +31,8 @@ from mlx.nn.layers.activations import (
gelu_approx,
gelu_fast_approx,
glu,
hard_shrink,
hard_tanh,
hardswish,
leaky_relu,
log_sigmoid,
@@ -40,6 +45,7 @@ from mlx.nn.layers.activations import (
sigmoid,
silu,
softmax,
softmin,
softplus,
softshrink,
softsign,

View File

@@ -286,6 +286,38 @@ def hardswish(x):
return x * mx.minimum(max_x_3, 6) / 6
@partial(mx.compile, shapeless=True)
def hard_tanh(x, min_val=-1.0, max_val=1.0):
r"""Applies the HardTanh function.
Applies :math:`\max(\min(x, \text{max\_val}), \text{min\_val})` element-wise.
"""
return mx.minimum(mx.maximum(x, min_val), max_val)
@partial(mx.compile, shapeless=True)
def hard_shrink(x, lambd=0.5):
r"""Applies the HardShrink activation function.
.. math::
\text{hardshrink}(x) = \begin{cases}
x & \text{if } x > \lambda \\
x & \text{if } x < -\lambda \\
0 & \text{otherwise}
\end{cases}
"""
return mx.where(mx.abs(x) > lambd, x, 0)
@partial(mx.compile, shapeless=True)
def softmin(x, axis=-1):
r"""Applies the Softmin function.
Applies :math:`\frac{e^{-x_i}}{\sum_j e^{-x_j}}` element-wise.
"""
return mx.softmax(-x, axis=axis)
def tanh(x):
"""Applies the hyperbolic tangent function.
@@ -579,3 +611,30 @@ class SELU(Module):
See :func:`selu` for the functional equivalent.
"""
@_make_activation_module(hard_tanh)
class HardTanh(Module):
r"""Applies the HardTanh function.
See :func:`hard_tanh` for the functional equivalent.
"""
@_make_activation_module(hard_shrink)
class HardShrink(Module):
r"""Applies the HardShrink function.
See :func:`hard_shrink` for the functional equivalent.
Args:
lambd: the :math:`\lambda` value for Hardshrink. Default: ``0.5``
"""
@_make_activation_module(softmin)
class Softmin(Module):
r"""Applies the Softmin function.
See :func:`softmin` for the functional equivalent.
"""