mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 17:31:16 +08:00)

Upsample with bicubic interpolation (#967)

This commit is contained in:
parent 99abb9eff4
commit 061cf9a4ce
@@ -1,9 +1,9 @@
 # Copyright © 2023-2024 Apple Inc.

 import operator
-from functools import reduce
+from functools import partial, reduce
 from itertools import product
-from typing import Literal, Tuple, Union
+from typing import Callable, Literal, Tuple, Union

 import mlx.core as mx
 from mlx.nn.layers.base import Module
@@ -17,7 +17,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
         step = 1 / scale
         start = ((M - 1) * step - N + 1) / 2
         indices = mx.arange(M, dtype=mx.float32) * step - start
-        indices = mx.clip(indices, 0, N - 1)

     shape = [1] * ndims
     shape[dim] = -1
+
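A side note, not part of the commit: with align_corners=False the step/start arithmetic shown above builds the familiar half-pixel sampling grid (i + 0.5) / scale - 0.5. A small standalone NumPy check of that equivalence, assuming the output length M is int(scale * N) and that scale * N is an integer so that step equals N / M:

import numpy as np

N, scale = 4, 2.0
M = int(scale * N)  # output length; M itself is defined earlier in _scaled_indices
step = 1 / scale
start = ((M - 1) * step - N + 1) / 2
grid = np.arange(M, dtype=np.float32) * step - start             # as in the hunk above
half_pixel = (np.arange(M, dtype=np.float32) + 0.5) / scale - 0.5
assert np.allclose(grid, half_pixel)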
@@ -30,6 +30,7 @@ def _nearest_indices(N, scale, dim, ndims):

 def _linear_indices(N, scale, align_corners, dim, ndims):
     indices = _scaled_indices(N, scale, align_corners, dim, ndims)
+    indices = mx.clip(indices, a_min=0, a_max=N - 1)
     indices_l = mx.floor(indices)
     indices_r = mx.ceil(indices)
     weight = indices - indices_l
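For orientation, not part of the commit: _linear_indices pairs every output position with its floor and ceil neighbors, and weight = indices - indices_l is the fractional offset, so (with the bookkeeping that follows below this hunk) a 1-D sample at fractional position p becomes (1 - w) * x[floor(p)] + w * x[ceil(p)]. A scalar NumPy sketch of that rule:

import numpy as np

x = np.array([0.0, 10.0, 20.0, 30.0])
p = 1.25                              # fractional sample position
lo, hi = int(np.floor(p)), int(np.ceil(p))
w = p - lo                            # weight of the right neighbor
value = (1 - w) * x[lo] + w * x[hi]
assert np.isclose(value, 12.5)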
@@ -41,6 +42,44 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
     )


+def _cubic_indices(N, scale, align_corners, dim, ndims):
+    indices = _scaled_indices(N, scale, align_corners, dim, ndims)
+    indices_l1 = mx.floor(indices)
+    indices_r1 = mx.floor(indices + 1)
+    indices_l2 = indices_l1 - 1
+    indices_r2 = indices_r1 + 1
+
+    @partial(mx.compile, shapeless=True)
+    def _get_weight(ind, grid, dist):
+        # PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
+        # and uses -0.75 for antialiasing=false (compatibility with OpenCV)
+        a = -0.75
+        x = mx.abs(ind - grid)
+        if dist == 1:
+            weight = ((a + 2.0) * x - (a + 3.0)) * x * x + 1
+        else:
+            weight = (((x - 5) * x + 8) * x - 4) * a
+        return weight
+
+    weight_l1 = _get_weight(indices, indices_l1, dist=1)[..., None]
+    weight_r1 = _get_weight(indices, indices_r1, dist=1)[..., None]
+    weight_l2 = _get_weight(indices, indices_l2, dist=2)[..., None]
+    weight_r2 = _get_weight(indices, indices_r2, dist=2)[..., None]
+
+    # padding with border value
+    indices_l1 = mx.clip(indices_l1, a_min=0, a_max=N - 1)
+    indices_r1 = mx.clip(indices_r1, a_min=0, a_max=N - 1)
+    indices_l2 = mx.clip(indices_l2, a_min=0, a_max=N - 1)
+    indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
+
+    return (
+        (indices_l1.astype(mx.int32), weight_l1),
+        (indices_r1.astype(mx.int32), weight_r1),
+        (indices_l2.astype(mx.int32), weight_l2),
+        (indices_r2.astype(mx.int32), weight_r2),
+    )
+
+
 def upsample_nearest(x: mx.array, scale_factor: Tuple):
     dims = x.ndim - 2
     if dims != len(scale_factor):
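The _get_weight closure above is the Keys cubic convolution kernel with a = -0.75 (the OpenCV / PyTorch-without-antialiasing choice, as the in-code comment notes): w(d) = (a + 2)d^3 - (a + 3)d^2 + 1 for the two inner neighbors (d <= 1) and w(d) = a(d^3 - 5d^2 + 8d - 4) for the two outer ones (1 < d < 2). The dist argument selects the branch instead of testing |x|, since the l1/r1 distances are always at most 1 and the l2/r2 distances always lie between 1 and 2. A quick standalone check, not part of the commit, that the four weights sum to 1 for any fractional offset t:

import numpy as np

def keys_weight(d, a=-0.75):
    # Keys cubic convolution kernel, matching _get_weight above
    d = abs(d)
    if d <= 1:
        return (a + 2) * d**3 - (a + 3) * d**2 + 1
    return a * (d**3 - 5 * d**2 + 8 * d - 4)

for t in (0.0, 0.25, 0.5, 0.9):
    # neighbor distances for a sample at fractional offset t: l2, l1, r1, r2
    weights = [keys_weight(d) for d in (1 + t, t, 1 - t, 2 - t)]
    assert np.isclose(sum(weights), 1.0)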
@@ -71,7 +110,9 @@ def upsample_nearest(x: mx.array, scale_factor: Tuple):
     return x[indices]


-def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
+def _interpolate(
+    x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
+):
     dims = x.ndim - 2
     if dims != len(scale_factor):
         raise ValueError("A scale needs to be provided for each spatial dimension")
@@ -81,7 +122,7 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
     # Compute the sampling grid
     indices = []
     for i, (n, s) in enumerate(zip(N, scale_factor)):
-        indices.append(_linear_indices(n, s, align_corners, i, dims))
+        indices.append(indices_fn(n, s, align_corners, i, dims))

     # Sample and compute the weights
     samples = []
@@ -95,6 +136,24 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
     return sum(wi * xi for wi, xi in zip(weights, samples))


+def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
+    return _interpolate(
+        x=x,
+        scale_factor=scale_factor,
+        indices_fn=_linear_indices,
+        align_corners=align_corners,
+    )
+
+
+def upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
+    return _interpolate(
+        x=x,
+        scale_factor=scale_factor,
+        indices_fn=_cubic_indices,
+        align_corners=align_corners,
+    )
+
+
 class Upsample(Module):
     r"""Upsample the input signal spatially.

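A note on the refactor, not part of the commit: _interpolate is separable. indices_fn returns, for each spatial axis, a list of (index, weight) pairs; the Cartesian product of those per-axis lists (itertools.product is already imported at the top of the file) enumerates the neighbor samples, and the final return sums weight products times samples, giving 4 terms for bilinear and 16 for bicubic. The gathering code sits between the hunks shown here, so the scalar sketch below only illustrates the combination rule:

from itertools import product

import numpy as np

x = np.arange(16.0).reshape(4, 4)      # a single-channel 4x4 "image"
axis_terms = [
    [(1, 0.75), (2, 0.25)],            # (index, weight) pairs along axis 0, position 1.25
    [(0, 0.50), (1, 0.50)],            # (index, weight) pairs along axis 1, position 0.5
]
value = sum(w0 * w1 * x[i0, i1] for (i0, w0), (i1, w1) in product(*axis_terms))
expected = 0.375 * x[1, 0] + 0.375 * x[1, 1] + 0.125 * x[2, 0] + 0.125 * x[2, 1]
assert np.isclose(value, expected)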
@@ -105,13 +164,14 @@ class Upsample(Module):
     For example, an audio signal would be 3D with 1 spatial dimension, an image
     4D with 2 and so on and so forth.

-    There are two upsampling algorithms implemented nearest neighbor upsampling
-    and linear interpolation. Both can be applied to any number of spatial
-    dimensions and the linear interpolation will be bilinear, trilinear etc
-    when applied to more than one spatial dimension.
+    There are three upsampling algorithms implemented: nearest neighbor upsampling,
+    linear interpolation, and cubic interpolation. All can be applied to any number
+    of spatial dimensions. The linear interpolation will be bilinear, trilinear, etc.
+    when applied to more than one spatial dimension, and cubic interpolation will be
+    bicubic when there are two spatial dimensions.

     .. note::
-       When using one of the linear interpolation modes the ``align_corners``
+       When using one of the linear or cubic interpolation modes the ``align_corners``
        argument changes how the corners are treated in the input image. If
        ``align_corners=True`` then the top and left edge of the input and
        output will be matching as will the bottom right edge.
@@ -121,10 +181,10 @@ class Upsample(Module):
             If a ``float`` is provided, it is the multiplier for all spatial dimensions.
             Otherwise, the number of scale factors provided must match the
             number of spatial dimensions.
-        mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
-            ``"linear"``. Default: ``"nearest"``.
+        mode (str, optional): The upsampling algorithm, either ``"nearest"``,
+            ``"linear"`` or ``"cubic"``. Default: ``"nearest"``.
         align_corners (bool, optional): Changes the way the corners are treated
-            during ``"linear"`` upsampling. See the note above and the
+            during ``"linear"`` and ``"cubic"`` upsampling. See the note above and the
             examples below for more details. Default: ``False``.

     Examples:
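The Examples block is truncated in this view; a minimal usage sketch for the new mode, assuming the NHWC layout that the test file below also uses:

import mlx.core as mx
import mlx.nn as nn

x = mx.arange(4, dtype=mx.float32).reshape(1, 2, 2, 1)   # (N, H, W, C)
up = nn.Upsample(scale_factor=2.0, mode="cubic", align_corners=True)
print(up(x).shape)  # (1, 4, 4, 1)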
@@ -163,7 +223,7 @@ class Upsample(Module):
         align_corners: bool = False,
     ):
         super().__init__()
-        if mode not in ["nearest", "linear"]:
+        if mode not in ["nearest", "linear", "cubic"]:
             raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
         if isinstance(scale_factor, (list, tuple)):
             self.scale_factor = tuple(map(float, scale_factor))
@@ -200,6 +260,9 @@ class Upsample(Module):

         if self.mode == "nearest":
             return upsample_nearest(x, scale_factor)
-
-        else:
+        elif self.mode == "linear":
             return upsample_linear(x, scale_factor, self.align_corners)
+        elif self.mode == "cubic":
+            return upsample_cubic(x, scale_factor, self.align_corners)
+        else:
+            raise Exception(f"Unknown interpolation mode: {self.mode}")
python/tests/test_upsample.py (new file, 99 additions)
@@ -0,0 +1,99 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import unittest
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx_tests
+import numpy as np
+
+try:
+    import torch
+    import torch.nn.functional as F
+
+    has_torch = True
+except ImportError as e:
+    has_torch = False
+
+
+class TestUpsample(mlx_tests.MLXTestCase):
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_upsample(self):
+        def run_upsample(
+            N,
+            C,
+            idim,
+            scale_factor,
+            mode,
+            align_corner,
+            dtype="float32",
+            atol=1e-5,
+        ):
+            with self.subTest(
+                N=N,
+                C=C,
+                idim=idim,
+                scale_factor=scale_factor,
+                mode=mode,
+                align_corner=align_corner,
+            ):
+                np_dtype = getattr(np, dtype)
+                np.random.seed(0)
+                iH, iW = idim
+                in_np = np.random.normal(-1.0, 1.0, (N, iH, iW, C)).astype(np_dtype)
+
+                in_mx = mx.array(in_np)
+                in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu")
+
+                out_mx = nn.Upsample(
+                    scale_factor=scale_factor,
+                    mode=mode,
+                    align_corners=align_corner,
+                )(in_mx)
+                mode_pt = {
+                    "linear": "bilinear",
+                    "cubic": "bicubic",
+                }[mode]
+                out_pt = F.interpolate(
+                    in_pt,
+                    scale_factor=scale_factor,
+                    mode=mode_pt,
+                    align_corners=align_corner,
+                )
+                out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
+                self.assertEqual(out_pt.shape, out_mx.shape)
+                self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
+
+        for dtype in ("float32",):
+            for N, C in ((1, 1), (2, 3)):
+                # only test cases in which target sizes are integers
+                # if not, there will be numerical difference between mlx
+                # and torch due to different indices selection.
+                for idim, scale_factor in (
+                    ((2, 2), (1.0, 1.0)),
+                    ((2, 2), (1.5, 1.5)),
+                    ((2, 2), (2.0, 2.0)),
+                    ((4, 4), (0.5, 0.5)),
+                    ((7, 7), (2.0, 2.0)),
+                    ((10, 10), (0.2, 0.2)),
+                    ((11, 21), (3.0, 3.0)),
+                    ((11, 21), (3.0, 2.0)),
+                ):
+                    # only test linear and cubic interpolation
+                    # there will be numerical difference in nearest
+                    # due to different indices selection.
+                    for mode in ("cubic", "linear"):
+                        for align_corner in (False, True):
+                            run_upsample(
+                                N,
+                                C,
+                                idim,
+                                scale_factor,
+                                mode,
+                                align_corner,
+                                dtype=dtype,
+                            )
+
+
+if __name__ == "__main__":
+    unittest.main()
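One last note on the test constraints, not part of the commit: the comments above restrict the comparison to integer target sizes. When scale * N is not an integer, the realized output length has to be rounded (here assumed to be int(scale * N)), so a sampling grid derived from the requested scale and one derived from the realized output size no longer coincide, and the two libraries may pick different source indices. A tiny illustration:

N, scale = 3, 1.5
M = int(scale * N)        # realized output length: 4, not 4.5
print(1 / scale, N / M)   # 0.666... vs 0.75, slightly different grid spacings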