mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Softshrink mapping + op (#552)
* Added Softshrink mapping + op * formatting * docs + nits in docstring --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -751,6 +751,21 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(y.shape, [3]) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|     def test_softshrink(self): | ||||
|         x = mx.array([1.0, -1.0, 0.0]) | ||||
|         y = nn.softshrink(x) | ||||
|         epsilon = 1e-4 | ||||
|         expected_y = mx.array([0.5, -0.5, 0.0]) | ||||
|         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) | ||||
|         self.assertEqual(y.shape, [3]) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|         y = nn.Softshrink(lambd=0.7)(x) | ||||
|         expected_y = mx.array([0.3, -0.3, 0.0]) | ||||
|         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) | ||||
|         self.assertEqual(y.shape, [3]) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|     def test_celu(self): | ||||
|         x = mx.array([1.0, -1.0, 0.0]) | ||||
|         y = nn.celu(x) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Andre Slavescu
					Andre Slavescu