mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Nearest upsample (#2202)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							237f9e58a8
						
					
				
				
					commit
					0359bf02c9
				
			| @@ -51,6 +51,7 @@ class TestUpsample(mlx_tests.MLXTestCase): | ||||
|                     align_corners=align_corner, | ||||
|                 )(in_mx) | ||||
|                 mode_pt = { | ||||
|                     "nearest": "nearest", | ||||
|                     "linear": "bilinear", | ||||
|                     "cubic": "bicubic", | ||||
|                 }[mode] | ||||
| @@ -58,7 +59,7 @@ class TestUpsample(mlx_tests.MLXTestCase): | ||||
|                     in_pt, | ||||
|                     scale_factor=scale_factor, | ||||
|                     mode=mode_pt, | ||||
|                     align_corners=align_corner, | ||||
|                     align_corners=align_corner if mode != "nearest" else None, | ||||
|                 ) | ||||
|                 out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) | ||||
|                 self.assertEqual(out_pt.shape, out_mx.shape) | ||||
| @@ -76,14 +77,14 @@ class TestUpsample(mlx_tests.MLXTestCase): | ||||
|                     ((4, 4), (0.5, 0.5)), | ||||
|                     ((7, 7), (2.0, 2.0)), | ||||
|                     ((10, 10), (0.2, 0.2)), | ||||
|                     ((10, 10), (0.3, 0.3)), | ||||
|                     ((11, 21), (3.0, 3.0)), | ||||
|                     ((11, 21), (3.0, 2.0)), | ||||
|                 ): | ||||
|                     # only test linear and cubic interpolation | ||||
|                     # there will be numerical difference in nearest | ||||
|                     # due to different indices selection. | ||||
|                     for mode in ("cubic", "linear"): | ||||
|                     for mode in ("cubic", "linear", "nearest"): | ||||
|                         for align_corner in (False, True): | ||||
|                             if mode == "nearest" and align_corner: | ||||
|                                 continue | ||||
|                             run_upsample( | ||||
|                                 N, | ||||
|                                 C, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user