mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	Upsample with bicubic interpolation (#967)
This commit is contained in:
		
							
								
								
									
										99
									
								
								python/tests/test_upsample.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								python/tests/test_upsample.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,99 @@
 | 
			
		||||
# Copyright © 2023-2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import mlx.core as mx
 | 
			
		||||
import mlx.nn as nn
 | 
			
		||||
import mlx_tests
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    import torch
 | 
			
		||||
    import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
    has_torch = True
 | 
			
		||||
except ImportError as e:
 | 
			
		||||
    has_torch = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestUpsample(mlx_tests.MLXTestCase):
 | 
			
		||||
    @unittest.skipIf(not has_torch, "requires Torch")
 | 
			
		||||
    def test_torch_upsample(self):
 | 
			
		||||
        def run_upsample(
 | 
			
		||||
            N,
 | 
			
		||||
            C,
 | 
			
		||||
            idim,
 | 
			
		||||
            scale_factor,
 | 
			
		||||
            mode,
 | 
			
		||||
            align_corner,
 | 
			
		||||
            dtype="float32",
 | 
			
		||||
            atol=1e-5,
 | 
			
		||||
        ):
 | 
			
		||||
            with self.subTest(
 | 
			
		||||
                N=N,
 | 
			
		||||
                C=C,
 | 
			
		||||
                idim=idim,
 | 
			
		||||
                scale_factor=scale_factor,
 | 
			
		||||
                mode=mode,
 | 
			
		||||
                align_corner=align_corner,
 | 
			
		||||
            ):
 | 
			
		||||
                np_dtype = getattr(np, dtype)
 | 
			
		||||
                np.random.seed(0)
 | 
			
		||||
                iH, iW = idim
 | 
			
		||||
                in_np = np.random.normal(-1.0, 1.0, (N, iH, iW, C)).astype(np_dtype)
 | 
			
		||||
 | 
			
		||||
                in_mx = mx.array(in_np)
 | 
			
		||||
                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
 | 
			
		||||
 | 
			
		||||
                out_mx = nn.Upsample(
 | 
			
		||||
                    scale_factor=scale_factor,
 | 
			
		||||
                    mode=mode,
 | 
			
		||||
                    align_corners=align_corner,
 | 
			
		||||
                )(in_mx)
 | 
			
		||||
                mode_pt = {
 | 
			
		||||
                    "linear": "bilinear",
 | 
			
		||||
                    "cubic": "bicubic",
 | 
			
		||||
                }[mode]
 | 
			
		||||
                out_pt = F.interpolate(
 | 
			
		||||
                    in_pt,
 | 
			
		||||
                    scale_factor=scale_factor,
 | 
			
		||||
                    mode=mode_pt,
 | 
			
		||||
                    align_corners=align_corner,
 | 
			
		||||
                )
 | 
			
		||||
                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
 | 
			
		||||
                self.assertEqual(out_pt.shape, out_mx.shape)
 | 
			
		||||
                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
 | 
			
		||||
 | 
			
		||||
        for dtype in ("float32",):
 | 
			
		||||
            for N, C in ((1, 1), (2, 3)):
 | 
			
		||||
                # only test cases in which target sizes are intergers
 | 
			
		||||
                # if not, there will be numerical difference between mlx
 | 
			
		||||
                # and torch due to different indices selection.
 | 
			
		||||
                for idim, scale_factor in (
 | 
			
		||||
                    ((2, 2), (1.0, 1.0)),
 | 
			
		||||
                    ((2, 2), (1.5, 1.5)),
 | 
			
		||||
                    ((2, 2), (2.0, 2.0)),
 | 
			
		||||
                    ((4, 4), (0.5, 0.5)),
 | 
			
		||||
                    ((7, 7), (2.0, 2.0)),
 | 
			
		||||
                    ((10, 10), (0.2, 0.2)),
 | 
			
		||||
                    ((11, 21), (3.0, 3.0)),
 | 
			
		||||
                    ((11, 21), (3.0, 2.0)),
 | 
			
		||||
                ):
 | 
			
		||||
                    # only test linear and cubic interpolation
 | 
			
		||||
                    # there will be numerical difference in nearest
 | 
			
		||||
                    # due to different indices selection.
 | 
			
		||||
                    for mode in ("cubic", "linear"):
 | 
			
		||||
                        for align_corner in (False, True):
 | 
			
		||||
                            run_upsample(
 | 
			
		||||
                                N,
 | 
			
		||||
                                C,
 | 
			
		||||
                                idim,
 | 
			
		||||
                                scale_factor,
 | 
			
		||||
                                mode,
 | 
			
		||||
                                align_corner,
 | 
			
		||||
                                dtype=dtype,
 | 
			
		||||
                            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
		Reference in New Issue
	
	Block a user