mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Upsample with bicubic interpolation (#967)
This commit is contained in:
parent
99abb9eff4
commit
061cf9a4ce
@ -1,9 +1,9 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import operator
|
import operator
|
||||||
from functools import reduce
|
from functools import partial, reduce
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Literal, Tuple, Union
|
from typing import Callable, Literal, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.nn.layers.base import Module
|
from mlx.nn.layers.base import Module
|
||||||
@ -17,7 +17,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
|
|||||||
step = 1 / scale
|
step = 1 / scale
|
||||||
start = ((M - 1) * step - N + 1) / 2
|
start = ((M - 1) * step - N + 1) / 2
|
||||||
indices = mx.arange(M, dtype=mx.float32) * step - start
|
indices = mx.arange(M, dtype=mx.float32) * step - start
|
||||||
indices = mx.clip(indices, 0, N - 1)
|
|
||||||
shape = [1] * ndims
|
shape = [1] * ndims
|
||||||
shape[dim] = -1
|
shape[dim] = -1
|
||||||
|
|
||||||
@ -30,6 +30,7 @@ def _nearest_indices(N, scale, dim, ndims):
|
|||||||
|
|
||||||
def _linear_indices(N, scale, align_corners, dim, ndims):
|
def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||||
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
|
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
|
||||||
|
indices = mx.clip(indices, a_min=0, a_max=N - 1)
|
||||||
indices_l = mx.floor(indices)
|
indices_l = mx.floor(indices)
|
||||||
indices_r = mx.ceil(indices)
|
indices_r = mx.ceil(indices)
|
||||||
weight = indices - indices_l
|
weight = indices - indices_l
|
||||||
@ -41,6 +42,44 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _cubic_indices(N, scale, align_corners, dim, ndims):
|
||||||
|
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
|
||||||
|
indices_l1 = mx.floor(indices)
|
||||||
|
indices_r1 = mx.floor(indices + 1)
|
||||||
|
indices_l2 = indices_l1 - 1
|
||||||
|
indices_r2 = indices_r1 + 1
|
||||||
|
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
|
def _get_weight(ind, grid, dist):
|
||||||
|
# PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
|
||||||
|
# and uses -0.75 for antialiasing=false (compatibility with OpenCV)
|
||||||
|
a = -0.75
|
||||||
|
x = mx.abs(ind - grid)
|
||||||
|
if dist == 1:
|
||||||
|
weight = ((a + 2.0) * x - (a + 3.0)) * x * x + 1
|
||||||
|
else:
|
||||||
|
weight = (((x - 5) * x + 8) * x - 4) * a
|
||||||
|
return weight
|
||||||
|
|
||||||
|
weight_l1 = _get_weight(indices, indices_l1, dist=1)[..., None]
|
||||||
|
weight_r1 = _get_weight(indices, indices_r1, dist=1)[..., None]
|
||||||
|
weight_l2 = _get_weight(indices, indices_l2, dist=2)[..., None]
|
||||||
|
weight_r2 = _get_weight(indices, indices_r2, dist=2)[..., None]
|
||||||
|
|
||||||
|
# padding with border value
|
||||||
|
indices_l1 = mx.clip(indices_l1, a_min=0, a_max=N - 1)
|
||||||
|
indices_r1 = mx.clip(indices_r1, a_min=0, a_max=N - 1)
|
||||||
|
indices_l2 = mx.clip(indices_l2, a_min=0, a_max=N - 1)
|
||||||
|
indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
|
||||||
|
|
||||||
|
return (
|
||||||
|
(indices_l1.astype(mx.int32), weight_l1),
|
||||||
|
(indices_r1.astype(mx.int32), weight_r1),
|
||||||
|
(indices_l2.astype(mx.int32), weight_l2),
|
||||||
|
(indices_r2.astype(mx.int32), weight_r2),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def upsample_nearest(x: mx.array, scale_factor: Tuple):
|
def upsample_nearest(x: mx.array, scale_factor: Tuple):
|
||||||
dims = x.ndim - 2
|
dims = x.ndim - 2
|
||||||
if dims != len(scale_factor):
|
if dims != len(scale_factor):
|
||||||
@ -71,7 +110,9 @@ def upsample_nearest(x: mx.array, scale_factor: Tuple):
|
|||||||
return x[indices]
|
return x[indices]
|
||||||
|
|
||||||
|
|
||||||
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
|
def _interpolate(
|
||||||
|
x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
|
||||||
|
):
|
||||||
dims = x.ndim - 2
|
dims = x.ndim - 2
|
||||||
if dims != len(scale_factor):
|
if dims != len(scale_factor):
|
||||||
raise ValueError("A scale needs to be provided for each spatial dimension")
|
raise ValueError("A scale needs to be provided for each spatial dimension")
|
||||||
@ -81,7 +122,7 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
|
|||||||
# Compute the sampling grid
|
# Compute the sampling grid
|
||||||
indices = []
|
indices = []
|
||||||
for i, (n, s) in enumerate(zip(N, scale_factor)):
|
for i, (n, s) in enumerate(zip(N, scale_factor)):
|
||||||
indices.append(_linear_indices(n, s, align_corners, i, dims))
|
indices.append(indices_fn(n, s, align_corners, i, dims))
|
||||||
|
|
||||||
# Sample and compute the weights
|
# Sample and compute the weights
|
||||||
samples = []
|
samples = []
|
||||||
@ -95,6 +136,24 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
|
|||||||
return sum(wi * xi for wi, xi in zip(weights, samples))
|
return sum(wi * xi for wi, xi in zip(weights, samples))
|
||||||
|
|
||||||
|
|
||||||
|
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
|
||||||
|
return _interpolate(
|
||||||
|
x=x,
|
||||||
|
scale_factor=scale_factor,
|
||||||
|
indices_fn=_linear_indices,
|
||||||
|
align_corners=align_corners,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
|
||||||
|
return _interpolate(
|
||||||
|
x=x,
|
||||||
|
scale_factor=scale_factor,
|
||||||
|
indices_fn=_cubic_indices,
|
||||||
|
align_corners=align_corners,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Upsample(Module):
|
class Upsample(Module):
|
||||||
r"""Upsample the input signal spatially.
|
r"""Upsample the input signal spatially.
|
||||||
|
|
||||||
@ -105,13 +164,14 @@ class Upsample(Module):
|
|||||||
For example, an audio signal would be 3D with 1 spatial dimension, an image
|
For example, an audio signal would be 3D with 1 spatial dimension, an image
|
||||||
4D with 2 and so on and so forth.
|
4D with 2 and so on and so forth.
|
||||||
|
|
||||||
There are two upsampling algorithms implemented nearest neighbor upsampling
|
There are three upsampling algorithms implemented nearest neighbor upsampling,
|
||||||
and linear interpolation. Both can be applied to any number of spatial
|
linear interpolation, and cubic interpolation. All can be applied to any number
|
||||||
dimensions and the linear interpolation will be bilinear, trilinear etc
|
of spatial dimensions. The linear interpolation will be bilinear, trilinear etc
|
||||||
when applied to more than one spatial dimension.
|
when applied to more than one spatial dimension. And cubic interpolation will be
|
||||||
|
bicubic when there are 2 spatial dimensions.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
When using one of the linear interpolation modes the ``align_corners``
|
When using one of the linear or cubic interpolation modes the ``align_corners``
|
||||||
argument changes how the corners are treated in the input image. If
|
argument changes how the corners are treated in the input image. If
|
||||||
``align_corners=True`` then the top and left edge of the input and
|
``align_corners=True`` then the top and left edge of the input and
|
||||||
output will be matching as will the bottom right edge.
|
output will be matching as will the bottom right edge.
|
||||||
@ -121,10 +181,10 @@ class Upsample(Module):
|
|||||||
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
|
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
|
||||||
Otherwise, the number of scale factors provided must match the
|
Otherwise, the number of scale factors provided must match the
|
||||||
number of spatial dimensions.
|
number of spatial dimensions.
|
||||||
mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
|
mode (str, optional): The upsampling algorithm, either ``"nearest"``,
|
||||||
``"linear"``. Default: ``"nearest"``.
|
``"linear"`` or ``"cubic"``. Default: ``"nearest"``.
|
||||||
align_corners (bool, optional): Changes the way the corners are treated
|
align_corners (bool, optional): Changes the way the corners are treated
|
||||||
during ``"linear"`` upsampling. See the note above and the
|
during ``"linear"`` and ``"cubic"`` upsampling. See the note above and the
|
||||||
examples below for more details. Default: ``False``.
|
examples below for more details. Default: ``False``.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
@ -163,7 +223,7 @@ class Upsample(Module):
|
|||||||
align_corners: bool = False,
|
align_corners: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if mode not in ["nearest", "linear"]:
|
if mode not in ["nearest", "linear", "cubic"]:
|
||||||
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
|
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
|
||||||
if isinstance(scale_factor, (list, tuple)):
|
if isinstance(scale_factor, (list, tuple)):
|
||||||
self.scale_factor = tuple(map(float, scale_factor))
|
self.scale_factor = tuple(map(float, scale_factor))
|
||||||
@ -200,6 +260,9 @@ class Upsample(Module):
|
|||||||
|
|
||||||
if self.mode == "nearest":
|
if self.mode == "nearest":
|
||||||
return upsample_nearest(x, scale_factor)
|
return upsample_nearest(x, scale_factor)
|
||||||
|
elif self.mode == "linear":
|
||||||
else:
|
|
||||||
return upsample_linear(x, scale_factor, self.align_corners)
|
return upsample_linear(x, scale_factor, self.align_corners)
|
||||||
|
elif self.mode == "cubic":
|
||||||
|
return upsample_cubic(x, scale_factor, self.align_corners)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unknown interpolation mode: {self.mode}")
|
||||||
|
99
python/tests/test_upsample.py
Normal file
99
python/tests/test_upsample.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
has_torch = True
|
||||||
|
except ImportError as e:
|
||||||
|
has_torch = False
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpsample(mlx_tests.MLXTestCase):
|
||||||
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
|
def test_torch_upsample(self):
|
||||||
|
def run_upsample(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
idim,
|
||||||
|
scale_factor,
|
||||||
|
mode,
|
||||||
|
align_corner,
|
||||||
|
dtype="float32",
|
||||||
|
atol=1e-5,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
N=N,
|
||||||
|
C=C,
|
||||||
|
idim=idim,
|
||||||
|
scale_factor=scale_factor,
|
||||||
|
mode=mode,
|
||||||
|
align_corner=align_corner,
|
||||||
|
):
|
||||||
|
np_dtype = getattr(np, dtype)
|
||||||
|
np.random.seed(0)
|
||||||
|
iH, iW = idim
|
||||||
|
in_np = np.random.normal(-1.0, 1.0, (N, iH, iW, C)).astype(np_dtype)
|
||||||
|
|
||||||
|
in_mx = mx.array(in_np)
|
||||||
|
in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
|
||||||
|
|
||||||
|
out_mx = nn.Upsample(
|
||||||
|
scale_factor=scale_factor,
|
||||||
|
mode=mode,
|
||||||
|
align_corners=align_corner,
|
||||||
|
)(in_mx)
|
||||||
|
mode_pt = {
|
||||||
|
"linear": "bilinear",
|
||||||
|
"cubic": "bicubic",
|
||||||
|
}[mode]
|
||||||
|
out_pt = F.interpolate(
|
||||||
|
in_pt,
|
||||||
|
scale_factor=scale_factor,
|
||||||
|
mode=mode_pt,
|
||||||
|
align_corners=align_corner,
|
||||||
|
)
|
||||||
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
||||||
|
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||||
|
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
|
||||||
|
|
||||||
|
for dtype in ("float32",):
|
||||||
|
for N, C in ((1, 1), (2, 3)):
|
||||||
|
# only test cases in which target sizes are intergers
|
||||||
|
# if not, there will be numerical difference between mlx
|
||||||
|
# and torch due to different indices selection.
|
||||||
|
for idim, scale_factor in (
|
||||||
|
((2, 2), (1.0, 1.0)),
|
||||||
|
((2, 2), (1.5, 1.5)),
|
||||||
|
((2, 2), (2.0, 2.0)),
|
||||||
|
((4, 4), (0.5, 0.5)),
|
||||||
|
((7, 7), (2.0, 2.0)),
|
||||||
|
((10, 10), (0.2, 0.2)),
|
||||||
|
((11, 21), (3.0, 3.0)),
|
||||||
|
((11, 21), (3.0, 2.0)),
|
||||||
|
):
|
||||||
|
# only test linear and cubic interpolation
|
||||||
|
# there will be numerical difference in nearest
|
||||||
|
# due to different indices selection.
|
||||||
|
for mode in ("cubic", "linear"):
|
||||||
|
for align_corner in (False, True):
|
||||||
|
run_upsample(
|
||||||
|
N,
|
||||||
|
C,
|
||||||
|
idim,
|
||||||
|
scale_factor,
|
||||||
|
mode,
|
||||||
|
align_corner,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user