mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Upsample with bicubic interpolation (#967)
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from functools import partial, reduce
|
||||
from itertools import product
|
||||
from typing import Literal, Tuple, Union
|
||||
from typing import Callable, Literal, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
@@ -17,7 +17,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
|
||||
step = 1 / scale
|
||||
start = ((M - 1) * step - N + 1) / 2
|
||||
indices = mx.arange(M, dtype=mx.float32) * step - start
|
||||
indices = mx.clip(indices, 0, N - 1)
|
||||
|
||||
shape = [1] * ndims
|
||||
shape[dim] = -1
|
||||
|
||||
@@ -30,6 +30,7 @@ def _nearest_indices(N, scale, dim, ndims):
|
||||
|
||||
def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
|
||||
indices = mx.clip(indices, a_min=0, a_max=N - 1)
|
||||
indices_l = mx.floor(indices)
|
||||
indices_r = mx.ceil(indices)
|
||||
weight = indices - indices_l
|
||||
@@ -41,6 +42,44 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||
)
|
||||
|
||||
|
||||
def _cubic_indices(N, scale, align_corners, dim, ndims):
|
||||
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
|
||||
indices_l1 = mx.floor(indices)
|
||||
indices_r1 = mx.floor(indices + 1)
|
||||
indices_l2 = indices_l1 - 1
|
||||
indices_r2 = indices_r1 + 1
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def _get_weight(ind, grid, dist):
|
||||
# PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
|
||||
# and uses -0.75 for antialiasing=false (compatibility with OpenCV)
|
||||
a = -0.75
|
||||
x = mx.abs(ind - grid)
|
||||
if dist == 1:
|
||||
weight = ((a + 2.0) * x - (a + 3.0)) * x * x + 1
|
||||
else:
|
||||
weight = (((x - 5) * x + 8) * x - 4) * a
|
||||
return weight
|
||||
|
||||
weight_l1 = _get_weight(indices, indices_l1, dist=1)[..., None]
|
||||
weight_r1 = _get_weight(indices, indices_r1, dist=1)[..., None]
|
||||
weight_l2 = _get_weight(indices, indices_l2, dist=2)[..., None]
|
||||
weight_r2 = _get_weight(indices, indices_r2, dist=2)[..., None]
|
||||
|
||||
# padding with border value
|
||||
indices_l1 = mx.clip(indices_l1, a_min=0, a_max=N - 1)
|
||||
indices_r1 = mx.clip(indices_r1, a_min=0, a_max=N - 1)
|
||||
indices_l2 = mx.clip(indices_l2, a_min=0, a_max=N - 1)
|
||||
indices_r2 = mx.clip(indices_r2, a_min=0, a_max=N - 1)
|
||||
|
||||
return (
|
||||
(indices_l1.astype(mx.int32), weight_l1),
|
||||
(indices_r1.astype(mx.int32), weight_r1),
|
||||
(indices_l2.astype(mx.int32), weight_l2),
|
||||
(indices_r2.astype(mx.int32), weight_r2),
|
||||
)
|
||||
|
||||
|
||||
def upsample_nearest(x: mx.array, scale_factor: Tuple):
|
||||
dims = x.ndim - 2
|
||||
if dims != len(scale_factor):
|
||||
@@ -71,7 +110,9 @@ def upsample_nearest(x: mx.array, scale_factor: Tuple):
|
||||
return x[indices]
|
||||
|
||||
|
||||
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
|
||||
def _interpolate(
|
||||
x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
|
||||
):
|
||||
dims = x.ndim - 2
|
||||
if dims != len(scale_factor):
|
||||
raise ValueError("A scale needs to be provided for each spatial dimension")
|
||||
@@ -81,7 +122,7 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
|
||||
# Compute the sampling grid
|
||||
indices = []
|
||||
for i, (n, s) in enumerate(zip(N, scale_factor)):
|
||||
indices.append(_linear_indices(n, s, align_corners, i, dims))
|
||||
indices.append(indices_fn(n, s, align_corners, i, dims))
|
||||
|
||||
# Sample and compute the weights
|
||||
samples = []
|
||||
@@ -95,6 +136,24 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
|
||||
return sum(wi * xi for wi, xi in zip(weights, samples))
|
||||
|
||||
|
||||
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
|
||||
return _interpolate(
|
||||
x=x,
|
||||
scale_factor=scale_factor,
|
||||
indices_fn=_linear_indices,
|
||||
align_corners=align_corners,
|
||||
)
|
||||
|
||||
|
||||
def upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
|
||||
return _interpolate(
|
||||
x=x,
|
||||
scale_factor=scale_factor,
|
||||
indices_fn=_cubic_indices,
|
||||
align_corners=align_corners,
|
||||
)
|
||||
|
||||
|
||||
class Upsample(Module):
|
||||
r"""Upsample the input signal spatially.
|
||||
|
||||
@@ -105,13 +164,14 @@ class Upsample(Module):
|
||||
For example, an audio signal would be 3D with 1 spatial dimension, an image
|
||||
4D with 2 and so on and so forth.
|
||||
|
||||
There are two upsampling algorithms implemented nearest neighbor upsampling
|
||||
and linear interpolation. Both can be applied to any number of spatial
|
||||
dimensions and the linear interpolation will be bilinear, trilinear etc
|
||||
when applied to more than one spatial dimension.
|
||||
There are three upsampling algorithms implemented nearest neighbor upsampling,
|
||||
linear interpolation, and cubic interpolation. All can be applied to any number
|
||||
of spatial dimensions. The linear interpolation will be bilinear, trilinear etc
|
||||
when applied to more than one spatial dimension. And cubic interpolation will be
|
||||
bicubic when there are 2 spatial dimensions.
|
||||
|
||||
.. note::
|
||||
When using one of the linear interpolation modes the ``align_corners``
|
||||
When using one of the linear or cubic interpolation modes the ``align_corners``
|
||||
argument changes how the corners are treated in the input image. If
|
||||
``align_corners=True`` then the top and left edge of the input and
|
||||
output will be matching as will the bottom right edge.
|
||||
@@ -121,10 +181,10 @@ class Upsample(Module):
|
||||
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
|
||||
Otherwise, the number of scale factors provided must match the
|
||||
number of spatial dimensions.
|
||||
mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
|
||||
``"linear"``. Default: ``"nearest"``.
|
||||
mode (str, optional): The upsampling algorithm, either ``"nearest"``,
|
||||
``"linear"`` or ``"cubic"``. Default: ``"nearest"``.
|
||||
align_corners (bool, optional): Changes the way the corners are treated
|
||||
during ``"linear"`` upsampling. See the note above and the
|
||||
during ``"linear"`` and ``"cubic"`` upsampling. See the note above and the
|
||||
examples below for more details. Default: ``False``.
|
||||
|
||||
Examples:
|
||||
@@ -163,7 +223,7 @@ class Upsample(Module):
|
||||
align_corners: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if mode not in ["nearest", "linear"]:
|
||||
if mode not in ["nearest", "linear", "cubic"]:
|
||||
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
|
||||
if isinstance(scale_factor, (list, tuple)):
|
||||
self.scale_factor = tuple(map(float, scale_factor))
|
||||
@@ -200,6 +260,9 @@ class Upsample(Module):
|
||||
|
||||
if self.mode == "nearest":
|
||||
return upsample_nearest(x, scale_factor)
|
||||
|
||||
else:
|
||||
elif self.mode == "linear":
|
||||
return upsample_linear(x, scale_factor, self.align_corners)
|
||||
elif self.mode == "cubic":
|
||||
return upsample_cubic(x, scale_factor, self.align_corners)
|
||||
else:
|
||||
raise Exception(f"Unknown interpolation mode: {self.mode}")
|
||||
|
Reference in New Issue
Block a user