diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py
index aac70e24a..6f813ba3f 100644
--- a/python/mlx/nn/layers/upsample.py
+++ b/python/mlx/nn/layers/upsample.py
@@ -1,9 +1,9 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import operator
-from functools import reduce
+from functools import partial, reduce
 from itertools import product
-from typing import Literal, Tuple, Union
+from typing import Callable, Literal, Tuple, Union
 
 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -17,7 +17,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
         step = 1 / scale
         start = ((M - 1) * step - N + 1) / 2
         indices = mx.arange(M, dtype=mx.float32) * step - start
-        indices = mx.clip(indices, 0, N - 1)
+
     shape = [1] * ndims
     shape[dim] = -1
 
@@ -30,6 +30,7 @@ def _nearest_indices(N, scale, dim, ndims):
 
 def _linear_indices(N, scale, align_corners, dim, ndims):
     indices = _scaled_indices(N, scale, align_corners, dim, ndims)
+    indices = mx.clip(indices, a_min=0, a_max=N - 1)
     indices_l = mx.floor(indices)
     indices_r = mx.ceil(indices)
     weight = indices - indices_l
@@ -41,6 +42,44 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
     )
 
 
+def _cubic_indices(N, scale, align_corners, dim, ndims):
+    indices = _scaled_indices(N, scale, align_corners, dim, ndims)
+    indices_l1 = mx.floor(indices)
+    indices_r1 = mx.floor(indices + 1)
+    indices_l2 = indices_l1 - 1
+    indices_r2 = indices_r1 + 1
+
+    @partial(mx.compile, shapeless=True)
+    def _get_weight(ind, grid, dist):
+        # PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
+        # and uses -0.75 for antialiasing=false (compatibility with OpenCV)
+        a = -0.75
+        x = mx.abs(ind - grid)
+        if dist == 1:
+            weight = ((a + 2.0) * x - (a + 3.0)) * x * x + 1
+        else:
+            weight = (((x - 5) * x + 8) * x - 4) * a
+        return weight
+
+    weight_l1 = _get_weight(indices, indices_l1, dist=1)[..., None]
+    weight_r1 = _get_weight(indices, indices_r1, dist=1)[..., None]
+    weight_l2 = _get_weight(indices, indices_l2, dist=2)[..., None]
+    weight_r2 = _get_weight(indices, indices_r2, dist=2)[..., None]
+
+    # padding with border value
+    indices_l1 = mx.clip(indices_l1, a_min=0, a_max=N - 1)
+    indices_r1 = mx.clip(indices_r1, a_min=0, a_max=N - 1)
+    indices_l2 = mx.clip(indices_l2, a_min=0, a_max=N - 1)
+    indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
+
+    return (
+        (indices_l1.astype(mx.int32), weight_l1),
+        (indices_r1.astype(mx.int32), weight_r1),
+        (indices_l2.astype(mx.int32), weight_l2),
+        (indices_r2.astype(mx.int32), weight_r2),
+    )
+
+
 def upsample_nearest(x: mx.array, scale_factor: Tuple):
     dims = x.ndim - 2
     if dims != len(scale_factor):
@@ -71,7 +110,9 @@ def upsample_nearest(x: mx.array, scale_factor: Tuple):
     return x[indices]
 
 
-def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
+def _interpolate(
+    x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
+):
     dims = x.ndim - 2
     if dims != len(scale_factor):
         raise ValueError("A scale needs to be provided for each spatial dimension")
@@ -81,7 +122,7 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
     # Compute the sampling grid
     indices = []
     for i, (n, s) in enumerate(zip(N, scale_factor)):
-        indices.append(_linear_indices(n, s, align_corners, i, dims))
+        indices.append(indices_fn(n, s, align_corners, i, dims))
 
     # Sample and compute the weights
     samples = []
@@ -95,6 +136,24 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
     return sum(wi * xi for wi, xi in zip(weights, samples))
 
 
+def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
+    return _interpolate(
+        x=x,
+        scale_factor=scale_factor,
+        indices_fn=_linear_indices,
+        align_corners=align_corners,
+    )
+
+
+def upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
+    return _interpolate(
+        x=x,
+        scale_factor=scale_factor,
+        indices_fn=_cubic_indices,
+        align_corners=align_corners,
+    )
+
+
 class Upsample(Module):
     r"""Upsample the input signal spatially.
 
@@ -105,13 +164,14 @@ class Upsample(Module):
     For example, an audio signal would be 3D with 1 spatial dimension, an image
     4D with 2 and so on and so forth.
 
-    There are two upsampling algorithms implemented nearest neighbor upsampling
-    and linear interpolation. Both can be applied to any number of spatial
-    dimensions and the linear interpolation will be bilinear, trilinear etc
-    when applied to more than one spatial dimension.
+    There are three upsampling algorithms implemented: nearest neighbor upsampling,
+    linear interpolation, and cubic interpolation. All can be applied to any number
+    of spatial dimensions. The linear interpolation will be bilinear, trilinear,
+    etc. when applied to more than one spatial dimension, and the cubic
+    interpolation will be bicubic when there are two spatial dimensions.
 
     .. note::
-       When using one of the linear interpolation modes the ``align_corners``
+       When using one of the linear or cubic interpolation modes, the ``align_corners``
        argument changes how the corners are treated in the input image. If
        ``align_corners=True`` then the top and left edge of the input and
        output will be matching as will the bottom right edge.
@@ -121,10 +181,10 @@
            If a ``float`` is provided, it is the multiplier for all spatial dimensions.
            Otherwise, the number of scale factors provided must match the
            number of spatial dimensions.
-        mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
-            ``"linear"``. Default: ``"nearest"``.
+        mode (str, optional): The upsampling algorithm, either ``"nearest"``,
+            ``"linear"`` or ``"cubic"``. Default: ``"nearest"``.
         align_corners (bool, optional): Changes the way the corners are treated
-            during ``"linear"`` upsampling. See the note above and the
+            during ``"linear"`` and ``"cubic"`` upsampling. See the note above and the
             examples below for more details. Default: ``False``.
 
     Examples:
@@ -163,7 +223,7 @@ class Upsample(Module):
         align_corners: bool = False,
     ):
         super().__init__()
-        if mode not in ["nearest", "linear"]:
+        if mode not in ["nearest", "linear", "cubic"]:
             raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
         if isinstance(scale_factor, (list, tuple)):
             self.scale_factor = tuple(map(float, scale_factor))
@@ -200,6 +260,9 @@ class Upsample(Module):
 
         if self.mode == "nearest":
             return upsample_nearest(x, scale_factor)
-
-        else:
+        elif self.mode == "linear":
             return upsample_linear(x, scale_factor, self.align_corners)
+        elif self.mode == "cubic":
+            return upsample_cubic(x, scale_factor, self.align_corners)
+        else:
+            raise Exception(f"Unknown interpolation mode: {self.mode}")
diff --git a/python/tests/test_upsample.py b/python/tests/test_upsample.py
new file mode 100644
index 000000000..402c7b0ca
--- /dev/null
+++ b/python/tests/test_upsample.py
@@ -0,0 +1,99 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import unittest
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx_tests
+import numpy as np
+
+try:
+    import torch
+    import torch.nn.functional as F
+
+    has_torch = True
+except ImportError as e:
+    has_torch = False
+
+
+class TestUpsample(mlx_tests.MLXTestCase):
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_upsample(self):
+        def run_upsample(
+            N,
+            C,
+            idim,
+            scale_factor,
+            mode,
+            align_corner,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                N=N,
+                C=C,
+                idim=idim,
+                scale_factor=scale_factor,
+                mode=mode,
+                align_corner=align_corner,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iH, iW = idim
+                in_np = np.random.normal(-1.0, 1.0, (N, iH, iW, C)).astype(np_dtype)
+
+                in_mx = mx.array(in_np)
+                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
+
+                out_mx = nn.Upsample(
+                    scale_factor=scale_factor,
+                    mode=mode,
+                    align_corners=align_corner,
+                )(in_mx)
+                mode_pt = {
+                    "linear": "bilinear",
+                    "cubic": "bicubic",
+                }[mode]
+                out_pt = F.interpolate(
+                    in_pt,
+                    scale_factor=scale_factor,
+                    mode=mode_pt,
+                    align_corners=align_corner,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C in ((1, 1), (2, 3)):
+                # only test cases in which target sizes are integers
+                # if not, there will be numerical differences between mlx
+                # and torch due to different indices selection.
+                for idim, scale_factor in (
+                    ((2, 2), (1.0, 1.0)),
+                    ((2, 2), (1.5, 1.5)),
+                    ((2, 2), (2.0, 2.0)),
+                    ((4, 4), (0.5, 0.5)),
+                    ((7, 7), (2.0, 2.0)),
+                    ((10, 10), (0.2, 0.2)),
+                    ((11, 21), (3.0, 3.0)),
+                    ((11, 21), (3.0, 2.0)),
+                ):
+                    # only test linear and cubic interpolation
+                    # there will be numerical differences in nearest
+                    # due to different indices selection.
+                    for mode in ("cubic", "linear"):
+                        for align_corner in (False, True):
+                            run_upsample(
+                                N,
+                                C,
+                                idim,
+                                scale_factor,
+                                mode,
+                                align_corner,
+                                dtype=dtype,
+                            )
+
+
+if __name__ == "__main__":
+    unittest.main()
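
A minimal usage sketch of the new mode (illustrative only, not part of the patch): with this change, nn.Upsample accepts mode="cubic" on channels-last inputs, mirroring the options the test above exercises.

    import mlx.core as mx
    import mlx.nn as nn

    # 4x4 single-channel image in the (N, H, W, C) layout the layer expects
    x = mx.arange(16, dtype=mx.float32).reshape(1, 4, 4, 1)

    # Bicubic upsampling by 2x; align_corners=False matches the default
    up = nn.Upsample(scale_factor=2.0, mode="cubic", align_corners=False)
    y = up(x)
    print(y.shape)  # (1, 8, 8, 1)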