mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
Fix the hard-shrink test (#1185)
This commit is contained in:
parent
0b7d71fd2f
commit
0fe6895893
@ -917,13 +917,13 @@ class TestLayers(mlx_tests.MLXTestCase):
|
|||||||
def test_hard_shrink(self):
|
def test_hard_shrink(self):
|
||||||
x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
|
x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
|
||||||
y = nn.hard_shrink(x)
|
y = nn.hard_shrink(x)
|
||||||
expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
|
expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
|
||||||
self.assertTrue(mx.array_equal(y, expected_y))
|
self.assertTrue(mx.array_equal(y, expected_y))
|
||||||
self.assertEqual(y.shape, (5,))
|
self.assertEqual(y.shape, (5,))
|
||||||
self.assertEqual(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
|
||||||
y = nn.hard_shrink(x, lambd=1.0)
|
y = nn.hard_shrink(x, lambd=0.1)
|
||||||
expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5])
|
expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5])
|
||||||
self.assertTrue(mx.array_equal(y, expected_y))
|
self.assertTrue(mx.array_equal(y, expected_y))
|
||||||
self.assertEqual(y.shape, (5,))
|
self.assertEqual(y.shape, (5,))
|
||||||
self.assertEqual(y.dtype, mx.float32)
|
self.assertEqual(y.dtype, mx.float32)
|
||||||
|
Loading…
Reference in New Issue
Block a user