mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Upsample2d (#414)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
		 Gabrijel Boduljak
					Gabrijel Boduljak
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							d729a1991b
						
					
				
				
					commit
					22364c40b7
				
			| @@ -67,3 +67,4 @@ from mlx.nn.layers.transformer import ( | ||||
|     TransformerEncoder, | ||||
|     TransformerEncoderLayer, | ||||
| ) | ||||
| from mlx.nn.layers.upsample import Upsample | ||||
|   | ||||
							
								
								
									
										205
									
								
								python/mlx/nn/layers/upsample.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										205
									
								
								python/mlx/nn/layers/upsample.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,205 @@ | ||||
| # Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| import operator | ||||
| from functools import reduce | ||||
| from itertools import product | ||||
| from typing import Literal, Tuple, Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| from mlx.nn.layers.base import Module | ||||
|  | ||||
|  | ||||
| def _scaled_indices(N, scale, align_corners, dim, ndims): | ||||
|     M = int(scale * N) | ||||
|     if align_corners: | ||||
|         indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1)) | ||||
|     else: | ||||
|         step = 1 / scale | ||||
|         start = ((M - 1) * step - N + 1) / 2 | ||||
|         indices = mx.arange(M, dtype=mx.float32) * step - start | ||||
|         indices = mx.clip(indices, 0, N - 1) | ||||
|     shape = [1] * ndims | ||||
|     shape[dim] = -1 | ||||
|  | ||||
|     return indices.reshape(shape) | ||||
|  | ||||
|  | ||||
| def _nearest_indices(N, scale, dim, ndims): | ||||
|     return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32) | ||||
|  | ||||
|  | ||||
| def _linear_indices(N, scale, align_corners, dim, ndims): | ||||
|     indices = _scaled_indices(N, scale, align_corners, dim, ndims) | ||||
|     indices_l = mx.floor(indices) | ||||
|     indices_r = mx.ceil(indices) | ||||
|     weight = indices - indices_l | ||||
|     weight = mx.expand_dims(weight, -1) | ||||
|  | ||||
|     return ( | ||||
|         (indices_l.astype(mx.int32), 1 - weight), | ||||
|         (indices_r.astype(mx.int32), weight), | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def upsample_nearest(x: mx.array, scale_factor: Tuple): | ||||
|     dims = x.ndim - 2 | ||||
|     if dims != len(scale_factor): | ||||
|         raise ValueError("A scale needs to be provided for each spatial dimension") | ||||
|  | ||||
|     # Integer scale_factors means we can simply expand-broadcast and reshape | ||||
|     if tuple(map(int, scale_factor)) == scale_factor: | ||||
|         shape = list(x.shape) | ||||
|         for d in range(dims): | ||||
|             shape.insert(2 + 2 * d, 1) | ||||
|         x = x.reshape(shape) | ||||
|         for d in range(dims): | ||||
|             shape[2 + 2 * d] = int(scale_factor[d]) | ||||
|         x = mx.broadcast_to(x, shape) | ||||
|         for d in range(dims): | ||||
|             shape[d + 1] *= shape[d + 2] | ||||
|             shape.pop(d + 2) | ||||
|         x = x.reshape(shape) | ||||
|         return x | ||||
|  | ||||
|     else: | ||||
|         B, *N, C = x.shape | ||||
|         indices = [slice(None)] | ||||
|         for i, (n, s) in enumerate(zip(N, scale_factor)): | ||||
|             indices.append(_nearest_indices(n, s, i, dims)) | ||||
|         indices = tuple(indices) | ||||
|  | ||||
|         return x[indices] | ||||
|  | ||||
|  | ||||
| def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False): | ||||
|     dims = x.ndim - 2 | ||||
|     if dims != len(scale_factor): | ||||
|         raise ValueError("A scale needs to be provided for each spatial dimension") | ||||
|  | ||||
|     B, *N, C = x.shape | ||||
|  | ||||
|     # Compute the sampling grid | ||||
|     indices = [] | ||||
|     for i, (n, s) in enumerate(zip(N, scale_factor)): | ||||
|         indices.append(_linear_indices(n, s, align_corners, i, dims)) | ||||
|  | ||||
|     # Sample and compute the weights | ||||
|     samples = [] | ||||
|     weights = [] | ||||
|     for idx_weight in product(*indices): | ||||
|         idx, weight = zip(*idx_weight) | ||||
|         samples.append(x[(slice(None),) + idx]) | ||||
|         weights.append(reduce(operator.mul, weight)) | ||||
|  | ||||
|     # Interpolate | ||||
|     return sum(wi * xi for wi, xi in zip(weights, samples)) | ||||
|  | ||||
|  | ||||
| class Upsample(Module): | ||||
|     r"""Upsample the input signal spatially. | ||||
|  | ||||
|     The spatial dimensions are by convention dimensions ``1`` to ``x.ndim - | ||||
|     2``. The first is the batch dimension and the last is the feature | ||||
|     dimension. | ||||
|  | ||||
|     For example, an audio signal would be 3D with 1 spatial dimension, an image | ||||
|     4D with 2 and so on and so forth. | ||||
|  | ||||
|     There are two upsampling algorithms implemented nearest neighbor upsampling | ||||
|     and linear interpolation. Both can be applied to any number of spatial | ||||
|     dimensions and the linear interpolation will be bilinear, trilinear etc | ||||
|     when applied to more than one spatial dimension. | ||||
|  | ||||
|     .. note:: | ||||
|        When using one of the linear interpolation modes the ``align_corners`` | ||||
|        argument changes how the corners are treated in the input image. If | ||||
|        ``align_corners=True`` then the top and left edge of the input and | ||||
|        output will be matching as will the bottom right edge. | ||||
|  | ||||
|     Parameters: | ||||
|         scale_factor (float or tuple): The multiplier for the spatial size. | ||||
|             If a ``float`` is provided, it is the multiplier for all spatial dimensions. | ||||
|             Otherwise, the number of scale factors provided must match the | ||||
|             number of spatial dimensions. | ||||
|         mode (str, optional): The upsampling algorithm, either ``"nearest"`` or | ||||
|             ``"linear"``. Default: ``"nearest"``. | ||||
|         align_corners (bool, optional): Changes the way the corners are treated | ||||
|             during ``"linear"`` upsampling.  See the note above and the | ||||
|             examples below for more details.  Default: ``False``. | ||||
|  | ||||
|     Examples: | ||||
|         >>> import mlx.core as mx | ||||
|         >>> import mlx.nn as nn | ||||
|         >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1)) | ||||
|         >>> x | ||||
|         array([[[[1], | ||||
|                  [2]], | ||||
|                 [[3], | ||||
|                  [4]]]], dtype=int32) | ||||
|         >>> n = nn.Upsample(scale_factor=2, mode='nearest') | ||||
|         >>> n(x).squeeze() | ||||
|         array([[1, 1, 2, 2], | ||||
|                [1, 1, 2, 2], | ||||
|                [3, 3, 4, 4], | ||||
|                [3, 3, 4, 4]], dtype=int32) | ||||
|         >>> b = nn.Upsample(scale_factor=2, mode='linear') | ||||
|         >>> b(x).squeeze() | ||||
|         array([[1, 1.25, 1.75, 2], | ||||
|                [1.5, 1.75, 2.25, 2.5], | ||||
|                [2.5, 2.75, 3.25, 3.5], | ||||
|                [3, 3.25, 3.75, 4]], dtype=float32) | ||||
|         >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True) | ||||
|         >>> b(x).squeeze() | ||||
|         array([[1, 1.33333, 1.66667, 2], | ||||
|                [1.66667, 2, 2.33333, 2.66667], | ||||
|                [2.33333, 2.66667, 3, 3.33333], | ||||
|                [3, 3.33333, 3.66667, 4]], dtype=float32) | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         scale_factor: Union[float, Tuple], | ||||
|         mode: Literal["nearest", "linear"] = "nearest", | ||||
|         align_corners: bool = False, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         if mode not in ["nearest", "linear"]: | ||||
|             raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}") | ||||
|         if isinstance(scale_factor, (list, tuple)): | ||||
|             self.scale_factor = tuple(map(float, scale_factor)) | ||||
|         else: | ||||
|             self.scale_factor = float(scale_factor) | ||||
|         self.mode = mode | ||||
|         self.align_corners = align_corners | ||||
|  | ||||
|     def _extra_repr(self) -> str: | ||||
|         return ( | ||||
|             f"scale_factor={self.scale_factor}, mode={self.mode!r}, " | ||||
|             f"align_corners={self.align_corners}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x: mx.array) -> mx.array: | ||||
|         dims = x.ndim - 2 | ||||
|         if dims <= 0: | ||||
|             raise ValueError( | ||||
|                 f"[Upsample] The input should have at least 1 spatial " | ||||
|                 f"dimension which means it should be at least 3D but " | ||||
|                 f"{x.ndim}D was provided" | ||||
|             ) | ||||
|  | ||||
|         scale_factor = self.scale_factor | ||||
|         if isinstance(scale_factor, tuple): | ||||
|             if len(scale_factor) != dims: | ||||
|                 raise ValueError( | ||||
|                     f"[Upsample] One scale per spatial dimension is required but " | ||||
|                     f"scale_factor={scale_factor} and the number of spatial " | ||||
|                     f"dimensions were {dims}" | ||||
|                 ) | ||||
|         else: | ||||
|             scale_factor = (scale_factor,) * dims | ||||
|  | ||||
|         if self.mode == "nearest": | ||||
|             return upsample_nearest(x, scale_factor) | ||||
|  | ||||
|         else: | ||||
|             return upsample_linear(x, scale_factor, self.align_corners) | ||||
| @@ -1,4 +1,4 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
| # Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| import os | ||||
| import tempfile | ||||
| @@ -8,7 +8,7 @@ import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
| import mlx_tests | ||||
| import numpy as np | ||||
| from mlx.utils import tree_flatten, tree_map, tree_unflatten | ||||
| from mlx.utils import tree_flatten, tree_map | ||||
|  | ||||
|  | ||||
| class TestBase(mlx_tests.MLXTestCase): | ||||
| @@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(y.shape, x.shape) | ||||
|         self.assertTrue(y.dtype, mx.float16) | ||||
|  | ||||
|     def test_upsample(self): | ||||
|         b, h, w, c = 1, 2, 2, 1 | ||||
|         scale_factor = 2 | ||||
|         upsample_nearest = nn.Upsample( | ||||
|             scale_factor=scale_factor, mode="nearest", align_corners=True | ||||
|         ) | ||||
|         upsample_bilinear = nn.Upsample( | ||||
|             scale_factor=scale_factor, mode="linear", align_corners=True | ||||
|         ) | ||||
|         upsample_nearest = nn.Upsample( | ||||
|             scale_factor=scale_factor, mode="nearest", align_corners=True | ||||
|         ) | ||||
|         upsample_bilinear_no_align_corners = nn.Upsample( | ||||
|             scale_factor=scale_factor, mode="linear", align_corners=False | ||||
|         ) | ||||
|         upsample_nearest_no_align_corners = nn.Upsample( | ||||
|             scale_factor=scale_factor, mode="nearest", align_corners=False | ||||
|         ) | ||||
|         # Test single feature map, align corners | ||||
|         x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1)) | ||||
|         expected_nearest = mx.array( | ||||
|             [[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]] | ||||
|         ).transpose((0, 2, 3, 1)) | ||||
|         expected_bilinear = mx.array( | ||||
|             [ | ||||
|                 [ | ||||
|                     [ | ||||
|                         [0, 0.333333, 0.666667, 1], | ||||
|                         [0.666667, 1, 1.33333, 1.66667], | ||||
|                         [1.33333, 1.66667, 2, 2.33333], | ||||
|                         [2, 2.33333, 2.66667, 3], | ||||
|                     ] | ||||
|                 ] | ||||
|             ] | ||||
|         ).transpose((0, 2, 3, 1)) | ||||
|         # Test single feature map, no align corners | ||||
|         x = ( | ||||
|             mx.arange(1, b * h * w * c + 1) | ||||
|             .reshape((b, c, h, w)) | ||||
|             .transpose((0, 2, 3, 1)) | ||||
|         ) | ||||
|         expected_bilinear_no_align_corners = mx.array( | ||||
|             [ | ||||
|                 [ | ||||
|                     [ | ||||
|                         [1.0000, 1.2500, 1.7500, 2.0000], | ||||
|                         [1.5000, 1.7500, 2.2500, 2.5000], | ||||
|                         [2.5000, 2.7500, 3.2500, 3.5000], | ||||
|                         [3.0000, 3.2500, 3.7500, 4.0000], | ||||
|                     ] | ||||
|                 ] | ||||
|             ] | ||||
|         ).transpose((0, 2, 3, 1)) | ||||
|         expected_nearest_no_align_corners = mx.array( | ||||
|             [[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]] | ||||
|         ).transpose((0, 2, 3, 1)) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners | ||||
|             ) | ||||
|         ) | ||||
|         self.assertTrue( | ||||
|             np.allclose( | ||||
|                 upsample_bilinear_no_align_corners(x), | ||||
|                 expected_bilinear_no_align_corners, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         # Test a more complex batch | ||||
|         b, h, w, c = 2, 3, 3, 2 | ||||
|         scale_factor = 2 | ||||
|         x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1)) | ||||
|  | ||||
|         upsample_nearest = nn.Upsample( | ||||
|             scale_factor=scale_factor, mode="nearest", align_corners=True | ||||
|         ) | ||||
|         upsample_bilinear = nn.Upsample( | ||||
|             scale_factor=scale_factor, mode="linear", align_corners=True | ||||
|         ) | ||||
|  | ||||
|         expected_nearest = mx.array( | ||||
|             [ | ||||
|                 [ | ||||
|                     [ | ||||
|                         [0.0, 0.0, 1.0, 1.0, 2.0, 2.0], | ||||
|                         [0.0, 0.0, 1.0, 1.0, 2.0, 2.0], | ||||
|                         [3.0, 3.0, 4.0, 4.0, 5.0, 5.0], | ||||
|                         [3.0, 3.0, 4.0, 4.0, 5.0, 5.0], | ||||
|                         [6.0, 6.0, 7.0, 7.0, 8.0, 8.0], | ||||
|                         [6.0, 6.0, 7.0, 7.0, 8.0, 8.0], | ||||
|                     ], | ||||
|                     [ | ||||
|                         [9.0, 9.0, 10.0, 10.0, 11.0, 11.0], | ||||
|                         [9.0, 9.0, 10.0, 10.0, 11.0, 11.0], | ||||
|                         [12.0, 12.0, 13.0, 13.0, 14.0, 14.0], | ||||
|                         [12.0, 12.0, 13.0, 13.0, 14.0, 14.0], | ||||
|                         [15.0, 15.0, 16.0, 16.0, 17.0, 17.0], | ||||
|                         [15.0, 15.0, 16.0, 16.0, 17.0, 17.0], | ||||
|                     ], | ||||
|                 ], | ||||
|                 [ | ||||
|                     [ | ||||
|                         [18.0, 18.0, 19.0, 19.0, 20.0, 20.0], | ||||
|                         [18.0, 18.0, 19.0, 19.0, 20.0, 20.0], | ||||
|                         [21.0, 21.0, 22.0, 22.0, 23.0, 23.0], | ||||
|                         [21.0, 21.0, 22.0, 22.0, 23.0, 23.0], | ||||
|                         [24.0, 24.0, 25.0, 25.0, 26.0, 26.0], | ||||
|                         [24.0, 24.0, 25.0, 25.0, 26.0, 26.0], | ||||
|                     ], | ||||
|                     [ | ||||
|                         [27.0, 27.0, 28.0, 28.0, 29.0, 29.0], | ||||
|                         [27.0, 27.0, 28.0, 28.0, 29.0, 29.0], | ||||
|                         [30.0, 30.0, 31.0, 31.0, 32.0, 32.0], | ||||
|                         [30.0, 30.0, 31.0, 31.0, 32.0, 32.0], | ||||
|                         [33.0, 33.0, 34.0, 34.0, 35.0, 35.0], | ||||
|                         [33.0, 33.0, 34.0, 34.0, 35.0, 35.0], | ||||
|                     ], | ||||
|                 ], | ||||
|             ] | ||||
|         ).transpose((0, 2, 3, 1)) | ||||
|         expected_bilinear = mx.array( | ||||
|             [ | ||||
|                 [ | ||||
|                     [ | ||||
|                         [0.0, 0.4, 0.8, 1.2, 1.6, 2.0], | ||||
|                         [1.2, 1.6, 2.0, 2.4, 2.8, 3.2], | ||||
|                         [2.4, 2.8, 3.2, 3.6, 4.0, 4.4], | ||||
|                         [3.6, 4.0, 4.4, 4.8, 5.2, 5.6], | ||||
|                         [4.8, 5.2, 5.6, 6.0, 6.4, 6.8], | ||||
|                         [6.0, 6.4, 6.8, 7.2, 7.6, 8.0], | ||||
|                     ], | ||||
|                     [ | ||||
|                         [9.0, 9.4, 9.8, 10.2, 10.6, 11.0], | ||||
|                         [10.2, 10.6, 11.0, 11.4, 11.8, 12.2], | ||||
|                         [11.4, 11.8, 12.2, 12.6, 13.0, 13.4], | ||||
|                         [12.6, 13.0, 13.4, 13.8, 14.2, 14.6], | ||||
|                         [13.8, 14.2, 14.6, 15.0, 15.4, 15.8], | ||||
|                         [15.0, 15.4, 15.8, 16.2, 16.6, 17.0], | ||||
|                     ], | ||||
|                 ], | ||||
|                 [ | ||||
|                     [ | ||||
|                         [18.0, 18.4, 18.8, 19.2, 19.6, 20.0], | ||||
|                         [19.2, 19.6, 20.0, 20.4, 20.8, 21.2], | ||||
|                         [20.4, 20.8, 21.2, 21.6, 22.0, 22.4], | ||||
|                         [21.6, 22.0, 22.4, 22.8, 23.2, 23.6], | ||||
|                         [22.8, 23.2, 23.6, 24.0, 24.4, 24.8], | ||||
|                         [24.0, 24.4, 24.8, 25.2, 25.6, 26.0], | ||||
|                     ], | ||||
|                     [ | ||||
|                         [27.0, 27.4, 27.8, 28.2, 28.6, 29.0], | ||||
|                         [28.2, 28.6, 29.0, 29.4, 29.8, 30.2], | ||||
|                         [29.4, 29.8, 30.2, 30.6, 31.0, 31.4], | ||||
|                         [30.6, 31.0, 31.4, 31.8, 32.2, 32.6], | ||||
|                         [31.8, 32.2, 32.6, 33.0, 33.4, 33.8], | ||||
|                         [33.0, 33.4, 33.8, 34.2, 34.6, 35.0], | ||||
|                     ], | ||||
|                 ], | ||||
|             ] | ||||
|         ).transpose((0, 2, 3, 1)) | ||||
|         self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest)) | ||||
|         self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear)) | ||||
|  | ||||
|         # Test different height and width scale_factor | ||||
|         b, h, w, c = 1, 2, 2, 2 | ||||
|         x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1)) | ||||
|         upsample_nearest = nn.Upsample( | ||||
|             scale_factor=(2, 3), mode="nearest", align_corners=True | ||||
|         ) | ||||
|         upsample_bilinear = nn.Upsample( | ||||
|             scale_factor=(2, 3), mode="linear", align_corners=True | ||||
|         ) | ||||
|  | ||||
|         expected_nearest = mx.array( | ||||
|             [ | ||||
|                 [ | ||||
|                     [ | ||||
|                         [0, 0, 0, 1, 1, 1], | ||||
|                         [0, 0, 0, 1, 1, 1], | ||||
|                         [2, 2, 2, 3, 3, 3], | ||||
|                         [2, 2, 2, 3, 3, 3], | ||||
|                     ], | ||||
|                     [ | ||||
|                         [4, 4, 4, 5, 5, 5], | ||||
|                         [4, 4, 4, 5, 5, 5], | ||||
|                         [6, 6, 6, 7, 7, 7], | ||||
|                         [6, 6, 6, 7, 7, 7], | ||||
|                     ], | ||||
|                 ] | ||||
|             ] | ||||
|         ).transpose((0, 2, 3, 1)) | ||||
|         expected_bilinear = mx.array( | ||||
|             [ | ||||
|                 [ | ||||
|                     [ | ||||
|                         [0, 0.2, 0.4, 0.6, 0.8, 1], | ||||
|                         [0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667], | ||||
|                         [1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333], | ||||
|                         [2, 2.2, 2.4, 2.6, 2.8, 3], | ||||
|                     ], | ||||
|                     [ | ||||
|                         [4, 4.2, 4.4, 4.6, 4.8, 5], | ||||
|                         [4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667], | ||||
|                         [5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333], | ||||
|                         [6, 6.2, 6.4, 6.6, 6.8, 7], | ||||
|                     ], | ||||
|                 ] | ||||
|             ] | ||||
|         ).transpose((0, 2, 3, 1)) | ||||
|         self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest)) | ||||
|         self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear)) | ||||
|  | ||||
|         # Test repr | ||||
|         self.assertEqual( | ||||
|             str(nn.Upsample(scale_factor=2)), | ||||
|             "Upsample(scale_factor=2.0, mode='nearest', align_corners=False)", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             str(nn.Upsample(scale_factor=(2, 3))), | ||||
|             "Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)", | ||||
|         ) | ||||
|  | ||||
|     def test_pooling(self): | ||||
|         # Test 1d pooling | ||||
|         x = mx.array( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user