mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Upsample2d (#414)
Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com> Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
d729a1991b
commit
22364c40b7
@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
|
||||
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
|
@ -40,3 +40,4 @@ Layers
|
||||
Softshrink
|
||||
Step
|
||||
Transformer
|
||||
Upsample
|
@ -67,3 +67,4 @@ from mlx.nn.layers.transformer import (
|
||||
TransformerEncoder,
|
||||
TransformerEncoderLayer,
|
||||
)
|
||||
from mlx.nn.layers.upsample import Upsample
|
||||
|
205
python/mlx/nn/layers/upsample.py
Normal file
205
python/mlx/nn/layers/upsample.py
Normal file
@ -0,0 +1,205 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from itertools import product
|
||||
from typing import Literal, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.nn.layers.base import Module
|
||||
|
||||
|
||||
def _scaled_indices(N, scale, align_corners, dim, ndims):
|
||||
M = int(scale * N)
|
||||
if align_corners:
|
||||
indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1))
|
||||
else:
|
||||
step = 1 / scale
|
||||
start = ((M - 1) * step - N + 1) / 2
|
||||
indices = mx.arange(M, dtype=mx.float32) * step - start
|
||||
indices = mx.clip(indices, 0, N - 1)
|
||||
shape = [1] * ndims
|
||||
shape[dim] = -1
|
||||
|
||||
return indices.reshape(shape)
|
||||
|
||||
|
||||
def _nearest_indices(N, scale, dim, ndims):
|
||||
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
|
||||
|
||||
|
||||
def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
|
||||
indices_l = mx.floor(indices)
|
||||
indices_r = mx.ceil(indices)
|
||||
weight = indices - indices_l
|
||||
weight = mx.expand_dims(weight, -1)
|
||||
|
||||
return (
|
||||
(indices_l.astype(mx.int32), 1 - weight),
|
||||
(indices_r.astype(mx.int32), weight),
|
||||
)
|
||||
|
||||
|
||||
def upsample_nearest(x: mx.array, scale_factor: Tuple):
|
||||
dims = x.ndim - 2
|
||||
if dims != len(scale_factor):
|
||||
raise ValueError("A scale needs to be provided for each spatial dimension")
|
||||
|
||||
# Integer scale_factors means we can simply expand-broadcast and reshape
|
||||
if tuple(map(int, scale_factor)) == scale_factor:
|
||||
shape = list(x.shape)
|
||||
for d in range(dims):
|
||||
shape.insert(2 + 2 * d, 1)
|
||||
x = x.reshape(shape)
|
||||
for d in range(dims):
|
||||
shape[2 + 2 * d] = int(scale_factor[d])
|
||||
x = mx.broadcast_to(x, shape)
|
||||
for d in range(dims):
|
||||
shape[d + 1] *= shape[d + 2]
|
||||
shape.pop(d + 2)
|
||||
x = x.reshape(shape)
|
||||
return x
|
||||
|
||||
else:
|
||||
B, *N, C = x.shape
|
||||
indices = [slice(None)]
|
||||
for i, (n, s) in enumerate(zip(N, scale_factor)):
|
||||
indices.append(_nearest_indices(n, s, i, dims))
|
||||
indices = tuple(indices)
|
||||
|
||||
return x[indices]
|
||||
|
||||
|
||||
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
|
||||
dims = x.ndim - 2
|
||||
if dims != len(scale_factor):
|
||||
raise ValueError("A scale needs to be provided for each spatial dimension")
|
||||
|
||||
B, *N, C = x.shape
|
||||
|
||||
# Compute the sampling grid
|
||||
indices = []
|
||||
for i, (n, s) in enumerate(zip(N, scale_factor)):
|
||||
indices.append(_linear_indices(n, s, align_corners, i, dims))
|
||||
|
||||
# Sample and compute the weights
|
||||
samples = []
|
||||
weights = []
|
||||
for idx_weight in product(*indices):
|
||||
idx, weight = zip(*idx_weight)
|
||||
samples.append(x[(slice(None),) + idx])
|
||||
weights.append(reduce(operator.mul, weight))
|
||||
|
||||
# Interpolate
|
||||
return sum(wi * xi for wi, xi in zip(weights, samples))
|
||||
|
||||
|
||||
class Upsample(Module):
|
||||
r"""Upsample the input signal spatially.
|
||||
|
||||
The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -
|
||||
2``. The first is the batch dimension and the last is the feature
|
||||
dimension.
|
||||
|
||||
For example, an audio signal would be 3D with 1 spatial dimension, an image
|
||||
4D with 2 and so on and so forth.
|
||||
|
||||
There are two upsampling algorithms implemented nearest neighbor upsampling
|
||||
and linear interpolation. Both can be applied to any number of spatial
|
||||
dimensions and the linear interpolation will be bilinear, trilinear etc
|
||||
when applied to more than one spatial dimension.
|
||||
|
||||
.. note::
|
||||
When using one of the linear interpolation modes the ``align_corners``
|
||||
argument changes how the corners are treated in the input image. If
|
||||
``align_corners=True`` then the top and left edge of the input and
|
||||
output will be matching as will the bottom right edge.
|
||||
|
||||
Parameters:
|
||||
scale_factor (float or tuple): The multiplier for the spatial size.
|
||||
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
|
||||
Otherwise, the number of scale factors provided must match the
|
||||
number of spatial dimensions.
|
||||
mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
|
||||
``"linear"``. Default: ``"nearest"``.
|
||||
align_corners (bool, optional): Changes the way the corners are treated
|
||||
during ``"linear"`` upsampling. See the note above and the
|
||||
examples below for more details. Default: ``False``.
|
||||
|
||||
Examples:
|
||||
>>> import mlx.core as mx
|
||||
>>> import mlx.nn as nn
|
||||
>>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
|
||||
>>> x
|
||||
array([[[[1],
|
||||
[2]],
|
||||
[[3],
|
||||
[4]]]], dtype=int32)
|
||||
>>> n = nn.Upsample(scale_factor=2, mode='nearest')
|
||||
>>> n(x).squeeze()
|
||||
array([[1, 1, 2, 2],
|
||||
[1, 1, 2, 2],
|
||||
[3, 3, 4, 4],
|
||||
[3, 3, 4, 4]], dtype=int32)
|
||||
>>> b = nn.Upsample(scale_factor=2, mode='linear')
|
||||
>>> b(x).squeeze()
|
||||
array([[1, 1.25, 1.75, 2],
|
||||
[1.5, 1.75, 2.25, 2.5],
|
||||
[2.5, 2.75, 3.25, 3.5],
|
||||
[3, 3.25, 3.75, 4]], dtype=float32)
|
||||
>>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
|
||||
>>> b(x).squeeze()
|
||||
array([[1, 1.33333, 1.66667, 2],
|
||||
[1.66667, 2, 2.33333, 2.66667],
|
||||
[2.33333, 2.66667, 3, 3.33333],
|
||||
[3, 3.33333, 3.66667, 4]], dtype=float32)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale_factor: Union[float, Tuple],
|
||||
mode: Literal["nearest", "linear"] = "nearest",
|
||||
align_corners: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if mode not in ["nearest", "linear"]:
|
||||
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
|
||||
if isinstance(scale_factor, (list, tuple)):
|
||||
self.scale_factor = tuple(map(float, scale_factor))
|
||||
else:
|
||||
self.scale_factor = float(scale_factor)
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def _extra_repr(self) -> str:
|
||||
return (
|
||||
f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
|
||||
f"align_corners={self.align_corners}"
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
dims = x.ndim - 2
|
||||
if dims <= 0:
|
||||
raise ValueError(
|
||||
f"[Upsample] The input should have at least 1 spatial "
|
||||
f"dimension which means it should be at least 3D but "
|
||||
f"{x.ndim}D was provided"
|
||||
)
|
||||
|
||||
scale_factor = self.scale_factor
|
||||
if isinstance(scale_factor, tuple):
|
||||
if len(scale_factor) != dims:
|
||||
raise ValueError(
|
||||
f"[Upsample] One scale per spatial dimension is required but "
|
||||
f"scale_factor={scale_factor} and the number of spatial "
|
||||
f"dimensions were {dims}"
|
||||
)
|
||||
else:
|
||||
scale_factor = (scale_factor,) * dims
|
||||
|
||||
if self.mode == "nearest":
|
||||
return upsample_nearest(x, scale_factor)
|
||||
|
||||
else:
|
||||
return upsample_linear(x, scale_factor, self.align_corners)
|
@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
@ -8,7 +8,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
||||
from mlx.utils import tree_flatten, tree_map
|
||||
|
||||
|
||||
class TestBase(mlx_tests.MLXTestCase):
|
||||
@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(y.shape, x.shape)
|
||||
self.assertTrue(y.dtype, mx.float16)
|
||||
|
||||
def test_upsample(self):
|
||||
b, h, w, c = 1, 2, 2, 1
|
||||
scale_factor = 2
|
||||
upsample_nearest = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="nearest", align_corners=True
|
||||
)
|
||||
upsample_bilinear = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="linear", align_corners=True
|
||||
)
|
||||
upsample_nearest = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="nearest", align_corners=True
|
||||
)
|
||||
upsample_bilinear_no_align_corners = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="linear", align_corners=False
|
||||
)
|
||||
upsample_nearest_no_align_corners = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="nearest", align_corners=False
|
||||
)
|
||||
# Test single feature map, align corners
|
||||
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
|
||||
expected_nearest = mx.array(
|
||||
[[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]
|
||||
).transpose((0, 2, 3, 1))
|
||||
expected_bilinear = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0, 0.333333, 0.666667, 1],
|
||||
[0.666667, 1, 1.33333, 1.66667],
|
||||
[1.33333, 1.66667, 2, 2.33333],
|
||||
[2, 2.33333, 2.66667, 3],
|
||||
]
|
||||
]
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
# Test single feature map, no align corners
|
||||
x = (
|
||||
mx.arange(1, b * h * w * c + 1)
|
||||
.reshape((b, c, h, w))
|
||||
.transpose((0, 2, 3, 1))
|
||||
)
|
||||
expected_bilinear_no_align_corners = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1.0000, 1.2500, 1.7500, 2.0000],
|
||||
[1.5000, 1.7500, 2.2500, 2.5000],
|
||||
[2.5000, 2.7500, 3.2500, 3.5000],
|
||||
[3.0000, 3.2500, 3.7500, 4.0000],
|
||||
]
|
||||
]
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
expected_nearest_no_align_corners = mx.array(
|
||||
[[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]]
|
||||
).transpose((0, 2, 3, 1))
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
upsample_bilinear_no_align_corners(x),
|
||||
expected_bilinear_no_align_corners,
|
||||
)
|
||||
)
|
||||
|
||||
# Test a more complex batch
|
||||
b, h, w, c = 2, 3, 3, 2
|
||||
scale_factor = 2
|
||||
x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
|
||||
|
||||
upsample_nearest = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="nearest", align_corners=True
|
||||
)
|
||||
upsample_bilinear = nn.Upsample(
|
||||
scale_factor=scale_factor, mode="linear", align_corners=True
|
||||
)
|
||||
|
||||
expected_nearest = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
|
||||
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
|
||||
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
|
||||
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
|
||||
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
|
||||
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
|
||||
],
|
||||
[
|
||||
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
|
||||
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
|
||||
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
|
||||
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
|
||||
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
|
||||
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
|
||||
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
|
||||
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
|
||||
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
|
||||
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
|
||||
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
|
||||
],
|
||||
[
|
||||
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
|
||||
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
|
||||
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
|
||||
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
|
||||
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
|
||||
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
|
||||
],
|
||||
],
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
expected_bilinear = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.0, 0.4, 0.8, 1.2, 1.6, 2.0],
|
||||
[1.2, 1.6, 2.0, 2.4, 2.8, 3.2],
|
||||
[2.4, 2.8, 3.2, 3.6, 4.0, 4.4],
|
||||
[3.6, 4.0, 4.4, 4.8, 5.2, 5.6],
|
||||
[4.8, 5.2, 5.6, 6.0, 6.4, 6.8],
|
||||
[6.0, 6.4, 6.8, 7.2, 7.6, 8.0],
|
||||
],
|
||||
[
|
||||
[9.0, 9.4, 9.8, 10.2, 10.6, 11.0],
|
||||
[10.2, 10.6, 11.0, 11.4, 11.8, 12.2],
|
||||
[11.4, 11.8, 12.2, 12.6, 13.0, 13.4],
|
||||
[12.6, 13.0, 13.4, 13.8, 14.2, 14.6],
|
||||
[13.8, 14.2, 14.6, 15.0, 15.4, 15.8],
|
||||
[15.0, 15.4, 15.8, 16.2, 16.6, 17.0],
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
[18.0, 18.4, 18.8, 19.2, 19.6, 20.0],
|
||||
[19.2, 19.6, 20.0, 20.4, 20.8, 21.2],
|
||||
[20.4, 20.8, 21.2, 21.6, 22.0, 22.4],
|
||||
[21.6, 22.0, 22.4, 22.8, 23.2, 23.6],
|
||||
[22.8, 23.2, 23.6, 24.0, 24.4, 24.8],
|
||||
[24.0, 24.4, 24.8, 25.2, 25.6, 26.0],
|
||||
],
|
||||
[
|
||||
[27.0, 27.4, 27.8, 28.2, 28.6, 29.0],
|
||||
[28.2, 28.6, 29.0, 29.4, 29.8, 30.2],
|
||||
[29.4, 29.8, 30.2, 30.6, 31.0, 31.4],
|
||||
[30.6, 31.0, 31.4, 31.8, 32.2, 32.6],
|
||||
[31.8, 32.2, 32.6, 33.0, 33.4, 33.8],
|
||||
[33.0, 33.4, 33.8, 34.2, 34.6, 35.0],
|
||||
],
|
||||
],
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
|
||||
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
|
||||
|
||||
# Test different height and width scale_factor
|
||||
b, h, w, c = 1, 2, 2, 2
|
||||
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
|
||||
upsample_nearest = nn.Upsample(
|
||||
scale_factor=(2, 3), mode="nearest", align_corners=True
|
||||
)
|
||||
upsample_bilinear = nn.Upsample(
|
||||
scale_factor=(2, 3), mode="linear", align_corners=True
|
||||
)
|
||||
|
||||
expected_nearest = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[2, 2, 2, 3, 3, 3],
|
||||
[2, 2, 2, 3, 3, 3],
|
||||
],
|
||||
[
|
||||
[4, 4, 4, 5, 5, 5],
|
||||
[4, 4, 4, 5, 5, 5],
|
||||
[6, 6, 6, 7, 7, 7],
|
||||
[6, 6, 6, 7, 7, 7],
|
||||
],
|
||||
]
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
expected_bilinear = mx.array(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0, 0.2, 0.4, 0.6, 0.8, 1],
|
||||
[0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],
|
||||
[1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],
|
||||
[2, 2.2, 2.4, 2.6, 2.8, 3],
|
||||
],
|
||||
[
|
||||
[4, 4.2, 4.4, 4.6, 4.8, 5],
|
||||
[4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],
|
||||
[5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],
|
||||
[6, 6.2, 6.4, 6.6, 6.8, 7],
|
||||
],
|
||||
]
|
||||
]
|
||||
).transpose((0, 2, 3, 1))
|
||||
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
|
||||
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
|
||||
|
||||
# Test repr
|
||||
self.assertEqual(
|
||||
str(nn.Upsample(scale_factor=2)),
|
||||
"Upsample(scale_factor=2.0, mode='nearest', align_corners=False)",
|
||||
)
|
||||
self.assertEqual(
|
||||
str(nn.Upsample(scale_factor=(2, 3))),
|
||||
"Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)",
|
||||
)
|
||||
|
||||
def test_pooling(self):
|
||||
# Test 1d pooling
|
||||
x = mx.array(
|
||||
|
Loading…
Reference in New Issue
Block a user