Fix the sigmoid module (#371)

This commit is contained in:
Angelos Katharopoulos 2024-01-04 13:16:36 -08:00 committed by GitHub
parent cf88db44b5
commit 75dc537e44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 9 deletions

View File

@ -13,6 +13,7 @@ from mlx.nn.layers.activations import (
PReLU,
ReLU,
ReLU6,
Sigmoid,
SiLU,
Softmax,
Softplus,

View File

@ -162,15 +162,6 @@ def gelu_fast_approx(x):
return x * mx.sigmoid(1.773 * x)
@_make_activation_module
class Sigmoid(Module):
r"""Applies the sigmoid function, element-wise.
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
"""
def step(x: mx.array, threshold: float = 0.0):
r"""Applies the Step Activation Function.
@ -240,6 +231,15 @@ def hardswish(x):
return x * mx.minimum(max_x_3, 6) / 6
@_make_activation_module(mx.sigmoid)
class Sigmoid(Module):
r"""Applies the sigmoid function, element-wise.
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
"""
@_make_activation_module(mish)
class Mish(Module):
r"""Applies the Mish function, element-wise.

View File

@ -657,6 +657,15 @@ class TestLayers(mlx_tests.MLXTestCase):
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
)
def test_sigmoid(self):
x = mx.array([1.0, 0.0, -1.0])
y1 = mx.sigmoid(x)
y2 = nn.activations.sigmoid(x)
y3 = nn.Sigmoid()(x)
self.assertEqualArray(y1, y2, atol=0, rtol=0)
self.assertEqualArray(y1, y3, atol=0, rtol=0)
def test_relu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.relu(x)