Activations LeakyReLU / PReLU / Softplus / Mish (#109)

* Leaky_relu / prelu / softplus / mish

* added tests

* updated bench

* remove torch refs, add init to PReLU

* added arXiv reference to mish

* added missing docs

This commit is contained in:
Diogo, 2023-12-11 22:40:57 -05:00 (committed by GitHub)
parent f5df47ec6e
commit 02de234ef0
8 changed files with 133 additions and 31 deletions

python/mlx/nn/layers/__init__.py

@@ -7,6 +7,8 @@ from mlx.nn.layers.activations import (
    SELU,
    LeakyReLU,
    LogSigmoid,
    Mish,
    PReLU,
    ReLU,
    ReLU6,
    SiLU,
@@ -19,6 +21,8 @@ from mlx.nn.layers.activations import (
    gelu_fast_approx,
    leaky_relu,
    log_sigmoid,
    mish,
    prelu,
    relu,
    relu6,
    selu,
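
For orientation, a quick import check (illustrative, not part of the diff) of the names this hunk adds; it assumes the package layout implied by the import block above.

# Sanity check (not from the commit): the new symbols should be importable
# next to the existing activations once this change is installed.
from mlx.nn.layers.activations import Mish, PReLU, mish, prelu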

python/mlx/nn/layers/activations.py

@@ -176,6 +176,33 @@ def selu(x):
    See also :func:`elu`.
    """
    return elu(x, 1.67326) * 1.0507


def prelu(x: mx.array, alpha: mx.array) -> mx.array:
    r"""Applies the element-wise function:

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    Here :math:`a` is an array.
    """
    return mx.maximum(0, x) + alpha * mx.minimum(0, x)


def mish(x: mx.array) -> mx.array:
    r"""Applies the Mish function, element-wise.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
    Reference: https://arxiv.org/abs/1908.08681

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
    """
    return x * mx.tanh(softplus(x))


@_make_activation_module(mish)
class Mish(Module):
    pass


@_make_activation_module(relu)
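
A small usage sketch (not from the commit) of the functional forms added above, assuming `mish` and `prelu` are re-exported from `mlx.nn` as the import hunk suggests:

# Illustrative only: exercise the new activations on a small array.
import mlx.core as mx
import mlx.nn as nn

x = mx.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# mish(x) = x * tanh(softplus(x)): negative inputs are damped rather than zeroed.
y_mish = nn.mish(x)

# prelu(x, alpha) = max(0, x) + alpha * min(0, x); here a fixed slope of 0.25.
alpha = mx.array([0.25])
y_prelu = nn.prelu(x, alpha)

print(y_mish)
print(y_prelu)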
@@ -257,6 +284,15 @@ class LogSigmoid(Module):
    pass


class PReLU(Module):
    def __init__(self, num_parameters=1, init=0.25):
        super().__init__()
        self.weight = mx.full([num_parameters], init)

    def __call__(self, x: mx.array):
        return prelu(x, self.weight)


class GELU(Module):
    r"""Applies the Gaussian Error Linear Units.