Upsample2d (#414)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Gabrijel Boduljak 2024-02-23 18:55:04 +01:00 committed by GitHub
parent d729a1991b
commit 22364c40b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 433 additions and 4 deletions

View File

@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals:
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
@ -253,4 +253,4 @@ Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.

View File

@ -40,3 +40,4 @@ Layers
Softshrink
Step
Transformer
Upsample

View File

@ -67,3 +67,4 @@ from mlx.nn.layers.transformer import (
TransformerEncoder,
TransformerEncoderLayer,
)
from mlx.nn.layers.upsample import Upsample

View File

@ -0,0 +1,205 @@
# Copyright © 2023-2024 Apple Inc.
import operator
from functools import reduce
from itertools import product
from typing import Literal, Tuple, Union
import mlx.core as mx
from mlx.nn.layers.base import Module
def _scaled_indices(N, scale, align_corners, dim, ndims):
M = int(scale * N)
if align_corners:
indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1))
else:
step = 1 / scale
start = ((M - 1) * step - N + 1) / 2
indices = mx.arange(M, dtype=mx.float32) * step - start
indices = mx.clip(indices, 0, N - 1)
shape = [1] * ndims
shape[dim] = -1
return indices.reshape(shape)
def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32)
def _linear_indices(N, scale, align_corners, dim, ndims):
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
indices_l = mx.floor(indices)
indices_r = mx.ceil(indices)
weight = indices - indices_l
weight = mx.expand_dims(weight, -1)
return (
(indices_l.astype(mx.int32), 1 - weight),
(indices_r.astype(mx.int32), weight),
)
def upsample_nearest(x: mx.array, scale_factor: Tuple):
dims = x.ndim - 2
if dims != len(scale_factor):
raise ValueError("A scale needs to be provided for each spatial dimension")
# Integer scale_factors means we can simply expand-broadcast and reshape
if tuple(map(int, scale_factor)) == scale_factor:
shape = list(x.shape)
for d in range(dims):
shape.insert(2 + 2 * d, 1)
x = x.reshape(shape)
for d in range(dims):
shape[2 + 2 * d] = int(scale_factor[d])
x = mx.broadcast_to(x, shape)
for d in range(dims):
shape[d + 1] *= shape[d + 2]
shape.pop(d + 2)
x = x.reshape(shape)
return x
else:
B, *N, C = x.shape
indices = [slice(None)]
for i, (n, s) in enumerate(zip(N, scale_factor)):
indices.append(_nearest_indices(n, s, i, dims))
indices = tuple(indices)
return x[indices]
def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
dims = x.ndim - 2
if dims != len(scale_factor):
raise ValueError("A scale needs to be provided for each spatial dimension")
B, *N, C = x.shape
# Compute the sampling grid
indices = []
for i, (n, s) in enumerate(zip(N, scale_factor)):
indices.append(_linear_indices(n, s, align_corners, i, dims))
# Sample and compute the weights
samples = []
weights = []
for idx_weight in product(*indices):
idx, weight = zip(*idx_weight)
samples.append(x[(slice(None),) + idx])
weights.append(reduce(operator.mul, weight))
# Interpolate
return sum(wi * xi for wi, xi in zip(weights, samples))
class Upsample(Module):
r"""Upsample the input signal spatially.
The spatial dimensions are by convention dimensions ``1`` to ``x.ndim -
2``. The first is the batch dimension and the last is the feature
dimension.
For example, an audio signal would be 3D with 1 spatial dimension, an image
4D with 2 and so on and so forth.
There are two upsampling algorithms implemented nearest neighbor upsampling
and linear interpolation. Both can be applied to any number of spatial
dimensions and the linear interpolation will be bilinear, trilinear etc
when applied to more than one spatial dimension.
.. note::
When using one of the linear interpolation modes the ``align_corners``
argument changes how the corners are treated in the input image. If
``align_corners=True`` then the top and left edge of the input and
output will be matching as will the bottom right edge.
Parameters:
scale_factor (float or tuple): The multiplier for the spatial size.
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
Otherwise, the number of scale factors provided must match the
number of spatial dimensions.
mode (str, optional): The upsampling algorithm, either ``"nearest"`` or
``"linear"``. Default: ``"nearest"``.
align_corners (bool, optional): Changes the way the corners are treated
during ``"linear"`` upsampling. See the note above and the
examples below for more details. Default: ``False``.
Examples:
>>> import mlx.core as mx
>>> import mlx.nn as nn
>>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
>>> x
array([[[[1],
[2]],
[[3],
[4]]]], dtype=int32)
>>> n = nn.Upsample(scale_factor=2, mode='nearest')
>>> n(x).squeeze()
array([[1, 1, 2, 2],
[1, 1, 2, 2],
[3, 3, 4, 4],
[3, 3, 4, 4]], dtype=int32)
>>> b = nn.Upsample(scale_factor=2, mode='linear')
>>> b(x).squeeze()
array([[1, 1.25, 1.75, 2],
[1.5, 1.75, 2.25, 2.5],
[2.5, 2.75, 3.25, 3.5],
[3, 3.25, 3.75, 4]], dtype=float32)
>>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
>>> b(x).squeeze()
array([[1, 1.33333, 1.66667, 2],
[1.66667, 2, 2.33333, 2.66667],
[2.33333, 2.66667, 3, 3.33333],
[3, 3.33333, 3.66667, 4]], dtype=float32)
"""
def __init__(
self,
scale_factor: Union[float, Tuple],
mode: Literal["nearest", "linear"] = "nearest",
align_corners: bool = False,
):
super().__init__()
if mode not in ["nearest", "linear"]:
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
if isinstance(scale_factor, (list, tuple)):
self.scale_factor = tuple(map(float, scale_factor))
else:
self.scale_factor = float(scale_factor)
self.mode = mode
self.align_corners = align_corners
def _extra_repr(self) -> str:
return (
f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
f"align_corners={self.align_corners}"
)
def __call__(self, x: mx.array) -> mx.array:
dims = x.ndim - 2
if dims <= 0:
raise ValueError(
f"[Upsample] The input should have at least 1 spatial "
f"dimension which means it should be at least 3D but "
f"{x.ndim}D was provided"
)
scale_factor = self.scale_factor
if isinstance(scale_factor, tuple):
if len(scale_factor) != dims:
raise ValueError(
f"[Upsample] One scale per spatial dimension is required but "
f"scale_factor={scale_factor} and the number of spatial "
f"dimensions were {dims}"
)
else:
scale_factor = (scale_factor,) * dims
if self.mode == "nearest":
return upsample_nearest(x, scale_factor)
else:
return upsample_linear(x, scale_factor, self.align_corners)

View File

@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import os
import tempfile
@ -8,7 +8,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx.utils import tree_flatten, tree_map
class TestBase(mlx_tests.MLXTestCase):
@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16)
def test_upsample(self):
b, h, w, c = 1, 2, 2, 1
scale_factor = 2
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=True
)
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear_no_align_corners = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=False
)
upsample_nearest_no_align_corners = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=False
)
# Test single feature map, align corners
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
expected_nearest = mx.array(
[[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.333333, 0.666667, 1],
[0.666667, 1, 1.33333, 1.66667],
[1.33333, 1.66667, 2, 2.33333],
[2, 2.33333, 2.66667, 3],
]
]
]
).transpose((0, 2, 3, 1))
# Test single feature map, no align corners
x = (
mx.arange(1, b * h * w * c + 1)
.reshape((b, c, h, w))
.transpose((0, 2, 3, 1))
)
expected_bilinear_no_align_corners = mx.array(
[
[
[
[1.0000, 1.2500, 1.7500, 2.0000],
[1.5000, 1.7500, 2.2500, 2.5000],
[2.5000, 2.7500, 3.2500, 3.5000],
[3.0000, 3.2500, 3.7500, 4.0000],
]
]
]
).transpose((0, 2, 3, 1))
expected_nearest_no_align_corners = mx.array(
[[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]]
).transpose((0, 2, 3, 1))
self.assertTrue(
np.allclose(
upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners
)
)
self.assertTrue(
np.allclose(
upsample_bilinear_no_align_corners(x),
expected_bilinear_no_align_corners,
)
)
# Test a more complex batch
b, h, w, c = 2, 3, 3, 2
scale_factor = 2
x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=True
)
expected_nearest = mx.array(
[
[
[
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
],
[
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
],
],
[
[
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
],
[
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0.0, 0.4, 0.8, 1.2, 1.6, 2.0],
[1.2, 1.6, 2.0, 2.4, 2.8, 3.2],
[2.4, 2.8, 3.2, 3.6, 4.0, 4.4],
[3.6, 4.0, 4.4, 4.8, 5.2, 5.6],
[4.8, 5.2, 5.6, 6.0, 6.4, 6.8],
[6.0, 6.4, 6.8, 7.2, 7.6, 8.0],
],
[
[9.0, 9.4, 9.8, 10.2, 10.6, 11.0],
[10.2, 10.6, 11.0, 11.4, 11.8, 12.2],
[11.4, 11.8, 12.2, 12.6, 13.0, 13.4],
[12.6, 13.0, 13.4, 13.8, 14.2, 14.6],
[13.8, 14.2, 14.6, 15.0, 15.4, 15.8],
[15.0, 15.4, 15.8, 16.2, 16.6, 17.0],
],
],
[
[
[18.0, 18.4, 18.8, 19.2, 19.6, 20.0],
[19.2, 19.6, 20.0, 20.4, 20.8, 21.2],
[20.4, 20.8, 21.2, 21.6, 22.0, 22.4],
[21.6, 22.0, 22.4, 22.8, 23.2, 23.6],
[22.8, 23.2, 23.6, 24.0, 24.4, 24.8],
[24.0, 24.4, 24.8, 25.2, 25.6, 26.0],
],
[
[27.0, 27.4, 27.8, 28.2, 28.6, 29.0],
[28.2, 28.6, 29.0, 29.4, 29.8, 30.2],
[29.4, 29.8, 30.2, 30.6, 31.0, 31.4],
[30.6, 31.0, 31.4, 31.8, 32.2, 32.6],
[31.8, 32.2, 32.6, 33.0, 33.4, 33.8],
[33.0, 33.4, 33.8, 34.2, 34.6, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
# Test different height and width scale_factor
b, h, w, c = 1, 2, 2, 2
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
upsample_nearest = nn.Upsample(
scale_factor=(2, 3), mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=(2, 3), mode="linear", align_corners=True
)
expected_nearest = mx.array(
[
[
[
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[2, 2, 2, 3, 3, 3],
[2, 2, 2, 3, 3, 3],
],
[
[4, 4, 4, 5, 5, 5],
[4, 4, 4, 5, 5, 5],
[6, 6, 6, 7, 7, 7],
[6, 6, 6, 7, 7, 7],
],
]
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.2, 0.4, 0.6, 0.8, 1],
[0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],
[1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],
[2, 2.2, 2.4, 2.6, 2.8, 3],
],
[
[4, 4.2, 4.4, 4.6, 4.8, 5],
[4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],
[5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],
[6, 6.2, 6.4, 6.6, 6.8, 7],
],
]
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
# Test repr
self.assertEqual(
str(nn.Upsample(scale_factor=2)),
"Upsample(scale_factor=2.0, mode='nearest', align_corners=False)",
)
self.assertEqual(
str(nn.Upsample(scale_factor=(2, 3))),
"Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)",
)
def test_pooling(self):
# Test 1d pooling
x = mx.array(