Fix the sigmoid module (#371)

This commit is contained in:
Angelos Katharopoulos
2024-01-04 13:16:36 -08:00
committed by GitHub
parent cf88db44b5
commit 75dc537e44
3 changed files with 19 additions and 9 deletions

View File

@@ -657,6 +657,15 @@ class TestLayers(mlx_tests.MLXTestCase):
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
)
def test_sigmoid(self):
x = mx.array([1.0, 0.0, -1.0])
y1 = mx.sigmoid(x)
y2 = nn.activations.sigmoid(x)
y3 = nn.Sigmoid()(x)
self.assertEqualArray(y1, y2, atol=0, rtol=0)
self.assertEqualArray(y1, y3, atol=0, rtol=0)
def test_relu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.relu(x)