Softshrink mapping + op (#552)

* Added Softshrink mapping + op

* formatting

* docs + nits in docstring

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Andre Slavescu
2024-01-30 15:56:28 -05:00
committed by GitHub
parent 3f7aba8498
commit d3a9005454
5 changed files with 70 additions and 21 deletions

View File

@@ -751,6 +751,21 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_softshrink(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.softshrink(x)
epsilon = 1e-4
expected_y = mx.array([0.5, -0.5, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
y = nn.Softshrink(lambd=0.7)(x)
expected_y = mx.array([0.3, -0.3, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.dtype, mx.float32)
def test_celu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.celu(x)