Fix the hard-shrink test (#1185)

This commit is contained in:
Angelos Katharopoulos 2024-06-04 16:22:56 -07:00 committed by GitHub
parent 0b7d71fd2f
commit 0fe6895893
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -917,13 +917,13 @@ class TestLayers(mlx_tests.MLXTestCase):
def test_hard_shrink(self): def test_hard_shrink(self):
x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5]) x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
y = nn.hard_shrink(x) y = nn.hard_shrink(x)
expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5]) expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
self.assertTrue(mx.array_equal(y, expected_y)) self.assertTrue(mx.array_equal(y, expected_y))
self.assertEqual(y.shape, (5,)) self.assertEqual(y.shape, (5,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
y = nn.hard_shrink(x, lambd=1.0) y = nn.hard_shrink(x, lambd=0.1)
expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5]) expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
self.assertTrue(mx.array_equal(y, expected_y)) self.assertTrue(mx.array_equal(y, expected_y))
self.assertEqual(y.shape, (5,)) self.assertEqual(y.shape, (5,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)