mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			101 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			101 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023-2024 Apple Inc.
 | |
| 
 | |
| import unittest
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx.nn as nn
 | |
| import mlx_tests
 | |
| import numpy as np
 | |
| 
 | |
| try:
 | |
|     import torch
 | |
|     import torch.nn.functional as F
 | |
| 
 | |
|     has_torch = True
 | |
| except ImportError as e:
 | |
|     has_torch = False
 | |
| 
 | |
| 
 | |
class TestUpsample(mlx_tests.MLXTestCase):
    """Compare ``mlx.nn.Upsample`` against PyTorch's ``F.interpolate``."""

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_upsample(self):
        """Check that mlx upsampling matches torch across input sizes, scale
        factors, interpolation modes and ``align_corners`` settings."""

        # mlx mode name -> the corresponding torch 2D interpolation mode.
        torch_modes = {
            "nearest": "nearest",
            "linear": "bilinear",
            "cubic": "bicubic",
        }

        def run_upsample(
            N,
            C,
            idim,
            scale_factor,
            mode,
            align_corner,
            dtype="float32",
            atol=1e-5,
        ):
            """Upsample one random input with both libraries and assert the
            outputs have the same shape and agree within ``atol``."""
            with self.subTest(
                N=N,
                C=C,
                idim=idim,
                scale_factor=scale_factor,
                mode=mode,
                align_corner=align_corner,
                dtype=dtype,
            ):
                np_dtype = getattr(np, dtype)
                np.random.seed(0)
                iH, iW = idim
                in_np = np.random.normal(-1.0, 1.0, (N, iH, iW, C)).astype(np_dtype)

                # The mlx input is laid out (N, H, W, C); torch gets the
                # transposed (N, C, H, W) view of the same data.
                in_mx = mx.array(in_np)
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2))

                out_mx = nn.Upsample(
                    scale_factor=scale_factor,
                    mode=mode,
                    align_corners=align_corner,
                )(in_mx)
                out_pt = F.interpolate(
                    in_pt,
                    scale_factor=scale_factor,
                    mode=torch_modes[mode],
                    # torch only accepts align_corners for interpolating
                    # modes, so it must be None for "nearest".
                    align_corners=align_corner if mode != "nearest" else None,
                )
                # Permute back to (N, H, W, C) so the layouts line up with mlx.
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)

                self.assertEqual(out_pt.shape, out_mx.shape)
                # Same tolerance rule as np.allclose(out_pt, out_mx, atol=atol)
                # (rtol=1e-5, relative to the mlx output), but a failure
                # reports the mismatching elements instead of just "False".
                np.testing.assert_allclose(
                    out_pt, np.array(out_mx), rtol=1e-5, atol=atol
                )

        for dtype in ("float32",):
            for N, C in ((1, 1), (2, 3)):
                # Only use scale factors for which the output sizes are whole
                # numbers. Otherwise mlx and torch may select different source
                # indices, and the outputs diverge numerically.
                for idim, scale_factor in (
                    ((2, 2), (1.0, 1.0)),
                    ((2, 2), (1.5, 1.5)),
                    ((2, 2), (2.0, 2.0)),
                    ((4, 4), (0.5, 0.5)),
                    ((7, 7), (2.0, 2.0)),
                    ((10, 10), (0.2, 0.2)),
                    ((10, 10), (0.3, 0.3)),
                    ((11, 21), (3.0, 3.0)),
                    ((11, 21), (3.0, 2.0)),
                ):
                    for mode in ("cubic", "linear", "nearest"):
                        for align_corner in (False, True):
                            # align_corners does not apply to "nearest".
                            if mode == "nearest" and align_corner:
                                continue
                            run_upsample(
                                N,
                                C,
                                idim,
                                scale_factor,
                                mode,
                                align_corner,
                                dtype=dtype,
                            )
| 
 | |
| 
 | |
if __name__ == "__main__":
    # When executed directly as a script, hand off to the project's
    # MLXTestRunner instead of the stock unittest entry point.
    mlx_tests.MLXTestRunner()
 | 
