Upsample with bicubic interpolation (#967)

This commit is contained in:
Shiyu 2024-04-11 06:47:22 +08:00 committed by GitHub
parent 99abb9eff4
commit 061cf9a4ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 178 additions and 16 deletions

View File

@ -1,9 +1,9 @@
# Copyright © 2023-2024 Apple Inc.
import operator
from functools import reduce
from functools import partial, reduce
from itertools import product
from typing import Literal, Tuple, Union
from typing import Callable, Literal, Tuple, Union
import mlx.core as mx
from mlx.nn.layers.base import Module
@ -17,7 +17,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
step = 1 / scale
start = ((M - 1) * step - N + 1) / 2
indices = mx.arange(M, dtype=mx.float32) * step - start
indices = mx.clip(indices, 0, N - 1)
shape = [1] * ndims
shape[dim] = -1
@ -30,6 +30,7 @@ def _nearest_indices(N, scale, dim, ndims):
def _linear_indices(N, scale, align_corners, dim, ndims):
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
indices = mx.clip(indices, a_min=0, a_max=N - 1)
indices_l = mx.floor(indices)
indices_r = mx.ceil(indices)
weight = indices - indices_l
@ -41,6 +42,44 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
)
def _cubic_indices(N, scale, align_corners, dim, ndims):
    """Build the four sampling taps used for cubic interpolation on one axis.

    The fractional source positions come from ``_scaled_indices``. For each
    position the two nearest grid points (``floor(p)`` and ``floor(p + 1)``)
    and the two points one step further out are used as taps.

    Args:
        N: Size of the input along this axis; tap indices are clipped to
            ``[0, N - 1]``.
        scale: Scale factor, forwarded to ``_scaled_indices``.
        align_corners: Forwarded to ``_scaled_indices``.
        dim: Forwarded to ``_scaled_indices``.
        ndims: Forwarded to ``_scaled_indices``.

    Returns:
        A tuple of four ``(index, weight)`` pairs ordered as inner-left,
        inner-right, outer-left, outer-right. Each index is an ``int32``
        array and each weight carries one extra trailing axis.
    """
    positions = _scaled_indices(N, scale, align_corners, dim, ndims)

    @partial(mx.compile, shapeless=True)
    def _kernel(ind, grid, dist):
        # PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
        # and uses -0.75 for antialiasing=false (compatibility with OpenCV)
        a = -0.75
        x = mx.abs(ind - grid)
        if dist == 1:
            # Polynomial applied to the two inner taps.
            return ((a + 2.0) * x - (a + 3.0)) * x * x + 1
        # Polynomial applied to the two outer taps.
        return (((x - 5) * x + 8) * x - 4) * a

    inner_left = mx.floor(positions)
    inner_right = mx.floor(positions + 1)

    # (unclipped grid position, which branch of the kernel to evaluate)
    taps = (
        (inner_left, 1),
        (inner_right, 1),
        (inner_left - 1, 2),
        (inner_right + 1, 2),
    )

    # Weights are evaluated on the unclipped grid positions; only the gather
    # indices are clipped, which pads with the border value.
    return tuple(
        (
            mx.clip(grid, a_min=0, a_max=N - 1).astype(mx.int32),
            _kernel(positions, grid, dist=dist)[..., None],
        )
        for grid, dist in taps
    )
def upsample_nearest(x: mx.array, scale_factor: Tuple):
dims = x.ndim - 2
if dims != len(scale_factor):
@ -71,7 +110,9 @@ def upsample_nearest(x: mx.array, scale_factor: Tuple):
return x[indices]
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
def _interpolate(
x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
):
dims = x.ndim - 2
if dims != len(scale_factor):
raise ValueError("A scale needs to be provided for each spatial dimension")
@ -81,7 +122,7 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
# Compute the sampling grid
indices = []
for i, (n, s) in enumerate(zip(N, scale_factor)):
indices.append(_linear_indices(n, s, align_corners, i, dims))
indices.append(indices_fn(n, s, align_corners, i, dims))
# Sample and compute the weights
samples = []
@ -95,6 +136,24 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
return sum(wi * xi for wi, xi in zip(weights, samples))
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
    """Upsample ``x`` via ``_interpolate`` using ``_linear_indices``.

    Args:
        x: The array to upsample.
        scale_factor: Scale factors, forwarded unchanged to ``_interpolate``.
        align_corners: Forwarded unchanged to ``_interpolate``.

    Returns:
        Whatever ``_interpolate`` returns for these arguments.
    """
    return _interpolate(x, scale_factor, _linear_indices, align_corners)
def upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
    """Upsample ``x`` via ``_interpolate`` using ``_cubic_indices``.

    Args:
        x: The array to upsample.
        scale_factor: Scale factors, forwarded unchanged to ``_interpolate``.
        align_corners: Forwarded unchanged to ``_interpolate``.

    Returns:
        Whatever ``_interpolate`` returns for these arguments.
    """
    return _interpolate(x, scale_factor, _cubic_indices, align_corners)
class Upsample(Module):
r"""Upsample the input signal spatially.
@ -105,13 +164,14 @@ class Upsample(Module):
For example, an audio signal would be 3D with 1 spatial dimension, an image
4D with 2 and so on and so forth.
There are two upsampling algorithms implemented nearest neighbor upsampling
and linear interpolation. Both can be applied to any number of spatial
dimensions and the linear interpolation will be bilinear, trilinear etc
when applied to more than one spatial dimension.
There are three upsampling algorithms implemented: nearest neighbor upsampling,
linear interpolation, and cubic interpolation. All can be applied to any number
of spatial dimensions. The linear interpolation will be bilinear, trilinear, etc.
when applied to more than one spatial dimension, and the cubic interpolation will
be bicubic when there are 2 spatial dimensions.
.. note::
When using one of the linear interpolation modes the ``align_corners``
When using one of the linear or cubic interpolation modes the ``align_corners``
argument changes how the corners are treated in the input image. If
``align_corners=True`` then the top and left edge of the input and
output will be matching as will the bottom right edge.
@ -121,10 +181,10 @@ class Upsample(Module):
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
Otherwise, the number of scale factors provided must match the
number of spatial dimensions.
mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
``"linear"``. Default: ``"nearest"``.
mode (str, optional): The upsampling algorithm, either ``"nearest"``,
``"linear"`` or ``"cubic"``. Default: ``"nearest"``.
align_corners (bool, optional): Changes the way the corners are treated
during ``"linear"`` upsampling. See the note above and the
during ``"linear"`` and ``"cubic"`` upsampling. See the note above and the
examples below for more details. Default: ``False``.
Examples:
@ -163,7 +223,7 @@ class Upsample(Module):
align_corners: bool = False,
):
super().__init__()
if mode not in ["nearest", "linear"]:
if mode not in ["nearest", "linear", "cubic"]:
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
if isinstance(scale_factor, (list, tuple)):
self.scale_factor = tuple(map(float, scale_factor))
@ -200,6 +260,9 @@ class Upsample(Module):
if self.mode == "nearest":
return upsample_nearest(x, scale_factor)
else:
elif self.mode == "linear":
return upsample_linear(x, scale_factor, self.align_corners)
elif self.mode == "cubic":
return upsample_cubic(x, scale_factor, self.align_corners)
else:
raise Exception(f"Unknown interpolation mode: {self.mode}")

View File

@ -0,0 +1,99 @@
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np
try:
import torch
import torch.nn.functional as F
has_torch = True
except ImportError as e:
has_torch = False
class TestUpsample(mlx_tests.MLXTestCase):
    """Checks ``nn.Upsample`` against ``torch.nn.functional.interpolate``."""

    def _compare_with_torch(
        self,
        N,
        C,
        idim,
        scale_factor,
        mode,
        align_corner,
        dtype="float32",
        atol=1e-5,
    ):
        """Run one MLX-vs-PyTorch comparison as a sub-test.

        A random array of shape ``(N, iH, iW, C)`` is upsampled with
        ``nn.Upsample``; the same data, transposed to ``(N, C, iH, iW)``, is
        upsampled with ``F.interpolate``. The two outputs must have equal
        shapes and agree within ``atol``.
        """
        with self.subTest(
            N=N,
            C=C,
            idim=idim,
            scale_factor=scale_factor,
            mode=mode,
            align_corner=align_corner,
        ):
            # Fixed seed so every sub-test sees reproducible data.
            np.random.seed(0)
            height, width = idim
            data = np.random.normal(-1.0, 1.0, (N, height, width, C))
            data = data.astype(getattr(np, dtype))

            upsample = nn.Upsample(
                scale_factor=scale_factor,
                mode=mode,
                align_corners=align_corner,
            )
            result_mx = upsample(mx.array(data))

            # Map the MLX mode name to the PyTorch name for 2 spatial dims.
            torch_mode = {
                "linear": "bilinear",
                "cubic": "bicubic",
            }[mode]
            torch_input = torch.from_numpy(data.transpose(0, 3, 1, 2)).to("cpu")
            result_pt = F.interpolate(
                torch_input,
                scale_factor=scale_factor,
                mode=torch_mode,
                align_corners=align_corner,
            )
            # Back to the (N, H, W, C) layout used for the MLX result.
            result_pt = torch.permute(result_pt, (0, 2, 3, 1)).numpy(force=True)

            self.assertEqual(result_pt.shape, result_mx.shape)
            self.assertTrue(np.allclose(result_pt, result_mx, atol=atol))

    @unittest.skipIf(not has_torch, "requires Torch")
    def test_torch_upsample(self):
        dtypes = ("float32",)
        batch_and_channels = ((1, 1), (2, 3))
        # only test cases in which target sizes are integers
        # if not, there will be numerical difference between mlx
        # and torch due to different indices selection.
        size_and_scale = (
            ((2, 2), (1.0, 1.0)),
            ((2, 2), (1.5, 1.5)),
            ((2, 2), (2.0, 2.0)),
            ((4, 4), (0.5, 0.5)),
            ((7, 7), (2.0, 2.0)),
            ((10, 10), (0.2, 0.2)),
            ((11, 21), (3.0, 3.0)),
            ((11, 21), (3.0, 2.0)),
        )
        # only test linear and cubic interpolation
        # there will be numerical difference in nearest
        # due to different indices selection.
        modes = ("cubic", "linear")
        corner_settings = (False, True)

        for dtype in dtypes:
            for N, C in batch_and_channels:
                for idim, scale_factor in size_and_scale:
                    for mode in modes:
                        for align_corner in corner_settings:
                            self._compare_with_torch(
                                N,
                                C,
                                idim,
                                scale_factor,
                                mode,
                                align_corner,
                                dtype=dtype,
                            )
# Run the tests with the unittest runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()