Fix the sigmoid module (#371)

Angelos Katharopoulos
2024-01-04 13:16:36 -08:00
committed by GitHub
parent cf88db44b5
commit 75dc537e44
3 changed files with 19 additions and 9 deletions


@@ -13,6 +13,7 @@ from mlx.nn.layers.activations import (
     PReLU,
     ReLU,
     ReLU6,
+    Sigmoid,
     SiLU,
     Softmax,
     Softplus,
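
With Sigmoid re-exported here, the module form of the activation can be used like any other layer. A minimal usage sketch, assuming the standard mlx.core / mlx.nn import aliases:

import mlx.core as mx
import mlx.nn as nn

# The Sigmoid module applies mx.sigmoid element-wise when called.
layer = nn.Sigmoid()
x = mx.array([-1.0, 0.0, 1.0])
y = layer(x)  # same result as mx.sigmoid(x)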


@@ -162,15 +162,6 @@ def gelu_fast_approx(x):
     return x * mx.sigmoid(1.773 * x)
-
-@_make_activation_module
-class Sigmoid(Module):
-    r"""Applies the sigmoid function, element-wise.
-
-    .. math::
-        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
-    """
-
 def step(x: mx.array, threshold: float = 0.0):
     r"""Applies the Step Activation Function.
@@ -240,6 +231,15 @@ def hardswish(x):
     return x * mx.minimum(max_x_3, 6) / 6
+
+@_make_activation_module(mx.sigmoid)
+class Sigmoid(Module):
+    r"""Applies the sigmoid function, element-wise.
+
+    .. math::
+        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+    """
+
 @_make_activation_module(mish)
 class Mish(Module):
     r"""Applies the Mish function, element-wise.