mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 19:26:42 +08:00
269 lines
9.4 KiB
Python
269 lines
9.4 KiB
Python
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
import operator
|
|
from functools import partial, reduce
|
|
from itertools import product
|
|
from typing import Callable, Literal, Tuple, Union
|
|
|
|
import mlx.core as mx
|
|
from mlx.nn.layers.base import Module
|
|
|
|
|
|
def _scaled_indices(N, scale, align_corners, dim, ndims):
    """Compute the fractional source position of every output sample on one axis.

    Args:
        N: Input size along the axis.
        scale: Scale factor for the axis; the output size is ``int(scale * N)``.
        align_corners: If ``True`` the first and last output samples map
            exactly onto the first and last input samples. Otherwise the
            output grid is centered over the input grid, so positions may fall
            slightly outside ``[0, N - 1]``.
        dim: Position of this axis among the spatial dimensions.
        ndims: Number of spatial dimensions.

    Returns:
        A ``float32`` array of source positions with ``ndims`` dimensions, all
        of size 1 except ``dim`` which holds the ``int(scale * N)`` positions.
    """
    M = int(scale * N)
    if align_corners:
        # With a single output sample there is no second corner to align; the
        # sample maps to input position 0. Guarding here avoids dividing by
        # ``M - 1 == 0``.
        step = (N - 1) / (M - 1) if M > 1 else 0.0
        indices = mx.arange(M, dtype=mx.float32) * step
    else:
        step = 1 / scale
        # Half of the difference between the output span and the input span.
        start = ((M - 1) * step - N + 1) / 2
        indices = mx.arange(M, dtype=mx.float32) * step - start

    # Make the positions broadcastable against the other spatial axes.
    shape = [1] * ndims
    shape[dim] = -1

    return indices.reshape(shape)
|
|
|
|
|
|
def _nearest_indices(N, scale, dim, ndims):
    """Return integer source indices for nearest-neighbor sampling on one axis.

    The positions are produced by ``_scaled_indices`` with corner alignment
    enabled and then cast to ``int32``.
    """
    positions = _scaled_indices(N, scale, True, dim, ndims)
    return positions.astype(mx.int32)
|
|
|
|
|
|
def _linear_indices(N, scale, align_corners, dim, ndims):
    """Return the two taps used for linear interpolation along one axis.

    Source positions are clamped to ``[0, N - 1]``. Each tap is an
    ``(indices, weight)`` pair: the first uses the floor of the position, the
    second the ceiling. The weights carry an extra trailing axis so they
    broadcast against the feature dimension.
    """
    pos = mx.clip(
        _scaled_indices(N, scale, align_corners, dim, ndims),
        a_min=0,
        a_max=N - 1,
    )
    lower = mx.floor(pos)
    upper = mx.ceil(pos)

    # Fractional distance from the lower neighbor, i.e. the weight of the upper one.
    frac = mx.expand_dims(pos - lower, -1)

    lower_tap = (lower.astype(mx.int32), 1 - frac)
    upper_tap = (upper.astype(mx.int32), frac)
    return (lower_tap, upper_tap)
|
|
|
|
|
|
def _cubic_indices(N, scale, align_corners, dim, ndims):
    """Return the four taps used for cubic interpolation along one axis.

    Each tap is an ``(indices, weight)`` pair. Taps are ordered as: nearest
    left, nearest right, far left, far right. Weights are computed from the
    unclamped neighbor positions; the indices are clamped to ``[0, N - 1]``
    afterwards so that out-of-range taps read the border value.
    """
    pos = _scaled_indices(N, scale, align_corners, dim, ndims)
    near_left = mx.floor(pos)
    near_right = mx.floor(pos + 1)
    far_left = near_left - 1
    far_right = near_right + 1

    @partial(mx.compile, shapeless=True)
    def _get_weight(ind, grid, dist):
        # PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
        # and uses -0.75 for antialiasing=false (compatibility with OpenCV)
        a = -0.75
        x = mx.abs(ind - grid)
        if dist == 1:
            return ((a + 2.0) * x - (a + 3.0)) * x * x + 1
        return (((x - 5) * x + 8) * x - 4) * a

    # (neighbor positions, kernel segment) in the order the taps are returned.
    neighbors = (
        (near_left, 1),
        (near_right, 1),
        (far_left, 2),
        (far_right, 2),
    )

    taps = []
    for grid, segment in neighbors:
        weight = _get_weight(pos, grid, dist=segment)[..., None]
        # padding with border value
        clamped = mx.clip(grid, a_min=0, a_max=N - 1)
        taps.append((clamped.astype(mx.int32), weight))

    return tuple(taps)
|
|
|
|
|
|
def upsample_nearest(x: mx.array, scale_factor: Tuple):
    """Upsample ``x`` with nearest-neighbor sampling.

    Args:
        x: Input whose first axis is the batch, last axis is the features and
            every axis in between is spatial.
        scale_factor: One scale per spatial dimension. Any sequence is
            accepted and treated like the equivalent tuple.

    Returns:
        The upsampled array.

    Raises:
        ValueError: If the number of scales differs from the number of
            spatial dimensions.
    """
    # Normalize to a tuple: the integer-scale test below compares against a
    # tuple, so a list would otherwise always fall through to the gather path
    # and produce different output than the same scales given as a tuple.
    scale_factor = tuple(scale_factor)

    dims = x.ndim - 2
    if dims != len(scale_factor):
        raise ValueError("A scale needs to be provided for each spatial dimension")

    # Integer scale_factors means we can simply expand-broadcast and reshape
    if tuple(map(int, scale_factor)) == scale_factor:
        shape = list(x.shape)
        # Insert a singleton axis right after every spatial axis ...
        for d in range(dims):
            shape.insert(2 + 2 * d, 1)
        x = x.reshape(shape)
        # ... broadcast that axis to the integer scale ...
        for d in range(dims):
            shape[2 + 2 * d] = int(scale_factor[d])
        x = mx.broadcast_to(x, shape)
        # ... and merge it back into the spatial axis it follows.
        for d in range(dims):
            shape[d + 1] *= shape[d + 2]
            shape.pop(d + 2)
        x = x.reshape(shape)
        return x

    # Fractional scales: gather with per-axis integer indices, keeping the
    # batch axis whole and leaving the trailing feature axis untouched.
    spatial_sizes = x.shape[1:-1]
    indices = [slice(None)]
    for i, (n, s) in enumerate(zip(spatial_sizes, scale_factor)):
        indices.append(_nearest_indices(n, s, i, dims))

    return x[tuple(indices)]
|
|
|
|
|
|
def _interpolate(
    x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
):
    """Upsample ``x`` as a weighted sum of gathered samples.

    ``indices_fn`` is called once per spatial axis and returns that axis'
    taps as ``(indices, weight)`` pairs. Every combination of one tap per axis
    contributes one gathered sample, weighted by the product of the chosen
    per-axis weights.

    Raises:
        ValueError: If the number of scales differs from the number of
            spatial dimensions.
    """
    n_spatial = x.ndim - 2
    if n_spatial != len(scale_factor):
        raise ValueError("A scale needs to be provided for each spatial dimension")

    _, *sizes, _ = x.shape

    # Compute the sampling grid: one group of taps per spatial axis
    taps_per_axis = [
        indices_fn(size, scale, align_corners, axis, n_spatial)
        for axis, (size, scale) in enumerate(zip(sizes, scale_factor))
    ]

    # Sample, weight and accumulate one term per combination of taps
    result = 0
    for combo in product(*taps_per_axis):
        idx, weight = zip(*combo)
        sample = x[(slice(None),) + idx]
        result = result + reduce(operator.mul, weight) * sample

    return result
|
|
|
|
|
|
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
    """Upsample ``x`` with linear interpolation over its spatial axes.

    Delegates to ``_interpolate`` using the two-tap linear index function.
    """
    return _interpolate(x, scale_factor, _linear_indices, align_corners)
|
|
|
|
|
|
def upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
    """Upsample ``x`` with cubic interpolation over its spatial axes.

    Delegates to ``_interpolate`` using the four-tap cubic index function.
    """
    return _interpolate(x, scale_factor, _cubic_indices, align_corners)
|
|
|
|
|
|
class Upsample(Module):
    r"""Upsample the input signal spatially.

    The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -
    2``. The first is the batch dimension and the last is the feature
    dimension.

    For example, an audio signal would be 3D with 1 spatial dimension, an image
    4D with 2 and so on and so forth.

    There are three upsampling algorithms implemented: nearest neighbor
    upsampling, linear interpolation, and cubic interpolation. All can be
    applied to any number of spatial dimensions. The linear interpolation will
    be bilinear, trilinear etc. when applied to more than one spatial
    dimension, and cubic interpolation will be bicubic when there are 2
    spatial dimensions.

    .. note::
       When using one of the linear or cubic interpolation modes the ``align_corners``
       argument changes how the corners are treated in the input image. If
       ``align_corners=True`` then the top and left edge of the input and
       output will be matching as will the bottom right edge.

    Parameters:
        scale_factor (float or tuple): The multiplier for the spatial size.
            If a ``float`` is provided, it is the multiplier for all spatial dimensions.
            Otherwise, the number of scale factors provided must match the
            number of spatial dimensions.
        mode (str, optional): The upsampling algorithm, either ``"nearest"``,
            ``"linear"`` or ``"cubic"``. Default: ``"nearest"``.
        align_corners (bool, optional): Changes the way the corners are treated
            during ``"linear"`` and ``"cubic"`` upsampling. See the note above and the
            examples below for more details. Default: ``False``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
        >>> x
        array([[[[1],
                 [2]],
                [[3],
                 [4]]]], dtype=int32)
        >>> n = nn.Upsample(scale_factor=2, mode='nearest')
        >>> n(x).squeeze()
        array([[1, 1, 2, 2],
               [1, 1, 2, 2],
               [3, 3, 4, 4],
               [3, 3, 4, 4]], dtype=int32)
        >>> b = nn.Upsample(scale_factor=2, mode='linear')
        >>> b(x).squeeze()
        array([[1, 1.25, 1.75, 2],
               [1.5, 1.75, 2.25, 2.5],
               [2.5, 2.75, 3.25, 3.5],
               [3, 3.25, 3.75, 4]], dtype=float32)
        >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
        >>> b(x).squeeze()
        array([[1, 1.33333, 1.66667, 2],
               [1.66667, 2, 2.33333, 2.66667],
               [2.33333, 2.66667, 3, 3.33333],
               [3, 3.33333, 3.66667, 4]], dtype=float32)
    """

    def __init__(
        self,
        scale_factor: Union[float, Tuple],
        mode: Literal["nearest", "linear", "cubic"] = "nearest",
        align_corners: bool = False,
    ):
        """Validate the mode and store the configuration.

        Raises:
            ValueError: If ``mode`` is not ``"nearest"``, ``"linear"`` or
                ``"cubic"``.
        """
        super().__init__()
        if mode not in ["nearest", "linear", "cubic"]:
            raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
        # Per-dimension scales are stored as a tuple of floats, a single scale
        # as a plain float that is expanded per input in ``__call__``.
        if isinstance(scale_factor, (list, tuple)):
            self.scale_factor = tuple(map(float, scale_factor))
        else:
            self.scale_factor = float(scale_factor)
        self.mode = mode
        self.align_corners = align_corners

    def _extra_repr(self) -> str:
        """Return the configuration formatted for the module's repr."""
        return (
            f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
            f"align_corners={self.align_corners}"
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Upsample ``x`` with the configured algorithm.

        Raises:
            ValueError: If ``x`` has fewer than 3 dimensions, if the number of
                scale factors does not match the number of spatial dimensions,
                or if ``mode`` was changed to an unknown value after
                construction.
        """
        dims = x.ndim - 2
        if dims <= 0:
            raise ValueError(
                f"[Upsample] The input should have at least 1 spatial "
                f"dimension which means it should be at least 3D but "
                f"{x.ndim}D was provided"
            )

        scale_factor = self.scale_factor
        if isinstance(scale_factor, tuple):
            if len(scale_factor) != dims:
                raise ValueError(
                    f"[Upsample] One scale per spatial dimension is required but "
                    f"scale_factor={scale_factor} and the number of spatial "
                    f"dimensions were {dims}"
                )
        else:
            # A single scale applies to every spatial dimension.
            scale_factor = (scale_factor,) * dims

        if self.mode == "nearest":
            return upsample_nearest(x, scale_factor)
        elif self.mode == "linear":
            return upsample_linear(x, scale_factor, self.align_corners)
        elif self.mode == "cubic":
            return upsample_cubic(x, scale_factor, self.align_corners)
        else:
            # Only reachable if ``self.mode`` was modified after ``__init__``.
            raise ValueError(f"Unknown interpolation mode: {self.mode}")
|