mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Fix the hard-shrink test (#1185)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							0b7d71fd2f
						
					
				
				
					commit
					0fe6895893
				
			| @@ -917,13 +917,13 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|     def test_hard_shrink(self): | ||||
|         x = mx.array([1.0, -0.5, 0.0, 0.5, -1.5]) | ||||
|         y = nn.hard_shrink(x) | ||||
|         expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5]) | ||||
|         expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5]) | ||||
|         self.assertTrue(mx.array_equal(y, expected_y)) | ||||
|         self.assertEqual(y.shape, (5,)) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|  | ||||
|         y = nn.hard_shrink(x, lambd=1.0) | ||||
|         expected_y = mx.array([1.0, 0.0, 0.0, 0.0, -1.5]) | ||||
|         y = nn.hard_shrink(x, lambd=0.1) | ||||
|         expected_y = mx.array([1.0, -0.5, 0.0, 0.5, -1.5]) | ||||
|         self.assertTrue(mx.array_equal(y, expected_y)) | ||||
|         self.assertEqual(y.shape, (5,)) | ||||
|         self.assertEqual(y.dtype, mx.float32) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user