mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:10:15 +08:00
Fix the sigmoid module (#371)
This commit is contained in:

committed by
GitHub

parent
cf88db44b5
commit
75dc537e44
@@ -657,6 +657,15 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
|
||||
)
|
||||
|
||||
def test_sigmoid(self):
|
||||
x = mx.array([1.0, 0.0, -1.0])
|
||||
y1 = mx.sigmoid(x)
|
||||
y2 = nn.activations.sigmoid(x)
|
||||
y3 = nn.Sigmoid()(x)
|
||||
|
||||
self.assertEqualArray(y1, y2, atol=0, rtol=0)
|
||||
self.assertEqualArray(y1, y3, atol=0, rtol=0)
|
||||
|
||||
def test_relu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.relu(x)
|
||||
|
Reference in New Issue
Block a user