Mirror of https://github.com/ml-explore/mlx.git, last synced 2025-06-25 01:41:17 +08:00.
Fix the sigmoid module (#371)
This commit is contained in:
parent
cf88db44b5
commit
75dc537e44
@ -13,6 +13,7 @@ from mlx.nn.layers.activations import (
|
||||
PReLU,
|
||||
ReLU,
|
||||
ReLU6,
|
||||
Sigmoid,
|
||||
SiLU,
|
||||
Softmax,
|
||||
Softplus,
|
||||
|
@ -162,15 +162,6 @@ def gelu_fast_approx(x):
|
||||
return x * mx.sigmoid(1.773 * x)
|
||||
|
||||
|
||||
# Bug fix: `_make_activation_module` is a decorator *factory* — it must be
# called with the element-wise function to wrap (as done for the other
# activation classes, e.g. `@_make_activation_module(mish)`). The bare
# `@_make_activation_module` form passed the class itself as the function,
# so `Sigmoid()` did not apply `mx.sigmoid`.
@_make_activation_module(mx.sigmoid)
class Sigmoid(Module):
    r"""Applies the sigmoid function, element-wise.

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
    """
|
||||
|
||||
|
||||
def step(x: mx.array, threshold: float = 0.0):
|
||||
r"""Applies the Step Activation Function.
|
||||
|
||||
@ -240,6 +231,15 @@ def hardswish(x):
|
||||
return x * mx.minimum(max_x_3, 6) / 6
|
||||
|
||||
|
||||
# Wrap mx.sigmoid so the module's __call__ applies it element-wise.
@_make_activation_module(mx.sigmoid)
class Sigmoid(Module):
    r"""Element-wise logistic sigmoid activation module.

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
    """
|
||||
|
||||
|
||||
@_make_activation_module(mish)
|
||||
class Mish(Module):
|
||||
r"""Applies the Mish function, element-wise.
|
||||
|
@ -657,6 +657,15 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
|
||||
)
|
||||
|
||||
def test_sigmoid(self):
    """Check the functional and module forms of sigmoid agree with mx.sigmoid."""
    inputs = mx.array([1.0, 0.0, -1.0])
    expected = mx.sigmoid(inputs)
    via_function = nn.activations.sigmoid(inputs)
    via_module = nn.Sigmoid()(inputs)

    # All three paths must produce bit-identical results.
    self.assertEqualArray(expected, via_function, atol=0, rtol=0)
    self.assertEqualArray(expected, via_module, atol=0, rtol=0)
|
||||
|
||||
def test_relu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.relu(x)
|
||||
|
Loading…
Reference in New Issue
Block a user