Added activation functions: leaky_relu relu6 softplus elu celu logsigmoid (#108)
* added leaky_relu relu6 softplus elu celu logsigmoid
* minor fixes for docstring and benchmark imports
* fixed elu implementation and added tests
* added tests for optional param, changed leaky_relu param to fit pytorch documentation
@@ -1,14 +1,26 @@
# Copyright © 2023 Apple Inc.

from mlx.nn.layers.activations import (
    CELU,
    ELU,
    GELU,
    LeakyReLU,
    LogSigmoid,
    ReLU,
    ReLU6,
    SiLU,
    Softplus,
    celu,
    elu,
    gelu,
    gelu_approx,
    gelu_fast_approx,
    leaky_relu,
    log_sigmoid,
    relu,
    relu6,
    silu,
    softplus,
)
from mlx.nn.layers.base import Module
from mlx.nn.layers.containers import Sequential
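For context, a minimal usage sketch of the newly re-exported names (not part of the commit; it assumes they surface through mlx.nn the same way the existing activations do):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    print(nn.leaky_relu(x))   # functional form, default negative_slope=0.01
    print(nn.ReLU6()(x))      # module form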
@@ -32,6 +32,47 @@ def relu(x):
    return mx.maximum(x, 0)


def leaky_relu(x, negative_slope=0.01):
    """Applies the Leaky Rectified Linear Unit.

    Simply ``mx.maximum(negative_slope * x, x)``.
    """
    return mx.maximum(negative_slope * x, x)


def elu(x, alpha=1.0):
    """Applies the Exponential Linear Unit.

    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
    """
    return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))


def relu6(x):
    r"""Applies the Rectified Linear Unit 6.

    Applies :math:`\min(\max(x, 0), 6)` element wise.
    """
    return mx.minimum(mx.maximum(x, 0), 6.0)


def softplus(x):
    r"""Applies the Softplus function.

    Applies :math:`\log(1 + \exp(x))` element wise.
    """
    return mx.logaddexp(x, 0)


def celu(x, alpha=1.0):
    r"""Applies the Continuously Differentiable Exponential Linear Unit.

    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
    element wise.
    """
    return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)


def silu(x):
    r"""Applies the Sigmoid Linear Unit.
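A quick numerical sanity check of the functional definitions above (a sketch, not part of the commit; with the default alpha of 1.0, elu and celu produce identical values):

    import mlx.core as mx

    x = mx.array([-2.0, 0.0, 3.0])
    print(elu(x))                     # ~[-0.8647, 0.0, 3.0]
    print(celu(x))                    # same as elu when alpha=1.0
    print(relu6(mx.array([8.0])))     # clipped to 6.0
    print(softplus(mx.array([0.0])))  # log(2) ~= 0.6931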
@@ -41,6 +82,14 @@ def silu(x):
    return x * mx.sigmoid(x)


def log_sigmoid(x):
    r"""Applies the Log Sigmoid function.

    Applies :math:`\log(\sigma(x)) = -\log(1 + e^{-x})` element wise.
    """
    return -softplus(-x)


def gelu(x):
    r"""Applies the Gaussian Error Linear Units function.
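A note on the ``-softplus(-x)`` formulation (a sketch of the numerical motivation, not part of the commit): computing log(sigmoid(x)) directly underflows for very negative inputs, whereas -logaddexp(-x, 0) stays finite.

    import mlx.core as mx

    x = mx.array([-200.0, 0.0, 200.0])
    print(-mx.logaddexp(-x, 0))    # ~[-200.0, -0.6931, 0.0]
    print(mx.log(mx.sigmoid(x)))   # [-inf, -0.6931, 0.0]; sigmoid(-200) underflows to 0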
@@ -99,11 +148,80 @@ class ReLU(Module):
    pass


class LeakyReLU(Module):
    r"""Applies the Leaky Rectified Linear Unit.

    Simply ``mx.maximum(negative_slope * x, x)``.

    Args:
        negative_slope: Controls the angle of the negative slope. Default: 1e-2.
    """

    def __init__(self, negative_slope=1e-2):
        super().__init__()
        self._negative_slope = negative_slope

    def __call__(self, x):
        return leaky_relu(x, self._negative_slope)


class ELU(Module):
    r"""Applies the Exponential Linear Unit.

    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.

    See :func:`elu`, for the functional equivalent.

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self._alpha = alpha

    def __call__(self, x):
        return elu(x, self._alpha)


@_make_activation_module(relu6)
class ReLU6(Module):
    pass


@_make_activation_module(softplus)
class Softplus(Module):
    pass


class CELU(Module):
    r"""Applies the Continuously Differentiable Exponential Linear Unit.

    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
    element wise.

    See :func:`celu`, for the functional equivalent.

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self._alpha = alpha

    def __call__(self, x):
        return celu(x, self._alpha)


@_make_activation_module(silu)
class SiLU(Module):
    pass


@_make_activation_module(log_sigmoid)
class LogSigmoid(Module):
    pass


class GELU(Module):
    r"""Applies the Gaussian Error Linear Units.
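A minimal end-to-end sketch of the new module classes inside a Sequential model (not part of the commit; layer sizes are hypothetical, and it assumes the classes are exported through mlx.nn alongside the existing layers):

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Sequential(
        nn.Linear(4, 8),
        nn.CELU(alpha=0.5),
        nn.Linear(8, 1),
        nn.LogSigmoid(),
    )
    y = model(mx.zeros((2, 4)))
    print(y.shape)  # (2, 1)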