mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:10:15 +08:00
Softshrink mapping + op (#552)
* Added Softshrink mapping + op * formatting * docs + nits in docstring --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -751,6 +751,21 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_softshrink(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.softshrink(x)
|
||||
epsilon = 1e-4
|
||||
expected_y = mx.array([0.5, -0.5, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
y = nn.Softshrink(lambd=0.7)(x)
|
||||
expected_y = mx.array([0.3, -0.3, 0.0])
|
||||
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.dtype, mx.float32)
|
||||
|
||||
def test_celu(self):
|
||||
x = mx.array([1.0, -1.0, 0.0])
|
||||
y = nn.celu(x)
|
||||
|
Reference in New Issue
Block a user