mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							4c1dfa58b7
						
					
				
				
					commit
					71de73a668
				
			| @@ -341,7 +341,7 @@ class TestConv(mlx_tests.MLXTestCase): | ||||
|                     atol, rtol = 1e-1, 1e-3 | ||||
|                 else: | ||||
|                     atol, rtol = 1e-5, 1e-6 | ||||
|                 self.assertTrue(np.allclose(out_pt, out_mx, atol=atol, rtol=rtol)) | ||||
|                 self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) | ||||
|  | ||||
|         for dtype in ("float32", "bfloat16"): | ||||
|             for N, C, O in ( | ||||
| @@ -1042,6 +1042,14 @@ class TestConv(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(mx.allclose(expected[0], grads[0])) | ||||
|         self.assertTrue(mx.allclose(expected[1], grads[1])) | ||||
|  | ||||
|     def test_repeated_conv(self): | ||||
|         x = mx.random.normal((1, 3, 3, 320)) | ||||
|         w = mx.random.normal((320, 3, 3, 320)) | ||||
|         for i in range(8): | ||||
|             y1 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1) | ||||
|             y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1) | ||||
|             self.assertTrue(mx.allclose(y1, y2)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user