mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Activations LeakyReLU / PReLU / Softplus / Mish (#109)
* Leaky_relu / prelu / softplus / mish * added tests * updated bench * remove torch refs, add init to PReLU * added arvix reference to mish * added missing docs
This commit is contained in:
@@ -7,6 +7,8 @@ from mlx.nn.layers.activations import (
|
||||
SELU,
|
||||
LeakyReLU,
|
||||
LogSigmoid,
|
||||
Mish,
|
||||
PReLU,
|
||||
ReLU,
|
||||
ReLU6,
|
||||
SiLU,
|
||||
@@ -19,6 +21,8 @@ from mlx.nn.layers.activations import (
|
||||
gelu_fast_approx,
|
||||
leaky_relu,
|
||||
log_sigmoid,
|
||||
mish,
|
||||
prelu,
|
||||
relu,
|
||||
relu6,
|
||||
selu,
|
||||
|
@@ -176,6 +176,33 @@ def selu(x):
|
||||
See also :func:`elu`.
|
||||
"""
|
||||
return elu(x, 1.67326) * 1.0507
|
||||
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
|
||||
r"""Applies the element-wise function:
|
||||
|
||||
.. math::
|
||||
\text{PReLU}(x) = \max(0,x) + a * \min(0,x)
|
||||
|
||||
Here :math:`a` is an array.
|
||||
"""
|
||||
return mx.maximum(0, x) + alpha * mx.minimum(0, x)
|
||||
|
||||
|
||||
def mish(x: mx.array) -> mx.array:
|
||||
r"""Applies the Mish function, element-wise.
|
||||
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
|
||||
|
||||
Reference: https://arxiv.org/abs/1908.08681
|
||||
|
||||
.. math::
|
||||
\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
|
||||
|
||||
"""
|
||||
return x * mx.tanh(softplus(x))
|
||||
|
||||
|
||||
@_make_activation_module(mish)
|
||||
class Mish(Module):
|
||||
pass
|
||||
|
||||
|
||||
@_make_activation_module(relu)
|
||||
@@ -257,6 +284,15 @@ class LogSigmoid(Module):
|
||||
pass
|
||||
|
||||
|
||||
class PReLU(Module):
|
||||
def __init__(self, num_parameters=1, init=0.25):
|
||||
super().__init__()
|
||||
self.weight = mx.full([num_parameters], init)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
return prelu(x, self.weight)
|
||||
|
||||
|
||||
class GELU(Module):
|
||||
r"""Applies the Gaussian Error Linear Units.
|
||||
|
||||
|
Reference in New Issue
Block a user