diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index c2cad615e..bde148fe8 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -11,7 +11,7 @@ MLX was developed with contributions from the following individuals: - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. -- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. +- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``. - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. @@ -253,4 +253,4 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and -limitations under the License. \ No newline at end of file +limitations under the License. diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index 0f5fca9db..f6755e8fe 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -40,3 +40,4 @@ Layers Softshrink Step Transformer + Upsample \ No newline at end of file diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 207cb01b2..d992b0426 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -67,3 +67,4 @@ from mlx.nn.layers.transformer import ( TransformerEncoder, TransformerEncoderLayer, ) +from mlx.nn.layers.upsample import Upsample diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py new file mode 100644 index 000000000..aac70e24a --- /dev/null +++ b/python/mlx/nn/layers/upsample.py @@ -0,0 +1,205 @@ +# Copyright © 2023-2024 Apple Inc. + +import operator +from functools import reduce +from itertools import product +from typing import Literal, Tuple, Union + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +def _scaled_indices(N, scale, align_corners, dim, ndims): + M = int(scale * N) + if align_corners: + indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1)) + else: + step = 1 / scale + start = ((M - 1) * step - N + 1) / 2 + indices = mx.arange(M, dtype=mx.float32) * step - start + indices = mx.clip(indices, 0, N - 1) + shape = [1] * ndims + shape[dim] = -1 + + return indices.reshape(shape) + + +def _nearest_indices(N, scale, dim, ndims): + return _scaled_indices(N, scale, True, dim, ndims).astype(mx.int32) + + +def _linear_indices(N, scale, align_corners, dim, ndims): + indices = _scaled_indices(N, scale, align_corners, dim, ndims) + indices_l = mx.floor(indices) + indices_r = mx.ceil(indices) + weight = indices - indices_l + weight = mx.expand_dims(weight, -1) + + return ( + (indices_l.astype(mx.int32), 1 - weight), + (indices_r.astype(mx.int32), weight), + ) + + +def upsample_nearest(x: mx.array, scale_factor: Tuple): + dims = x.ndim - 2 + if dims != len(scale_factor): + raise ValueError("A scale needs to be provided for each spatial dimension") + + # Integer scale_factors means we can simply expand-broadcast and reshape + if tuple(map(int, scale_factor)) == scale_factor: + shape = list(x.shape) + for d in range(dims): + shape.insert(2 + 2 * d, 1) + x = x.reshape(shape) + for d in range(dims): + shape[2 + 2 * d] = int(scale_factor[d]) + x = mx.broadcast_to(x, shape) + for d in range(dims): + shape[d + 1] *= shape[d + 2] + shape.pop(d + 2) + x = x.reshape(shape) + return x + + else: + B, *N, C = x.shape + indices = [slice(None)] + for i, (n, s) in enumerate(zip(N, scale_factor)): + indices.append(_nearest_indices(n, s, i, dims)) + indices = tuple(indices) + + return x[indices] + + +def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False): + dims = x.ndim - 2 + if dims != len(scale_factor): + raise ValueError("A scale needs to be provided for each spatial dimension") + + B, *N, C = x.shape + + # Compute the sampling grid + indices = [] + for i, (n, s) in enumerate(zip(N, scale_factor)): + indices.append(_linear_indices(n, s, align_corners, i, dims)) + + # Sample and compute the weights + samples = [] + weights = [] + for idx_weight in product(*indices): + idx, weight = zip(*idx_weight) + samples.append(x[(slice(None),) + idx]) + weights.append(reduce(operator.mul, weight)) + + # Interpolate + return sum(wi * xi for wi, xi in zip(weights, samples)) + + +class Upsample(Module): + r"""Upsample the input signal spatially. + + The spatial dimensions are by convention dimensions ``1`` to ``x.ndim - + 2``. The first is the batch dimension and the last is the feature + dimension. + + For example, an audio signal would be 3D with 1 spatial dimension, an image + 4D with 2 and so on and so forth. + + There are two upsampling algorithms implemented nearest neighbor upsampling + and linear interpolation. Both can be applied to any number of spatial + dimensions and the linear interpolation will be bilinear, trilinear etc + when applied to more than one spatial dimension. + + .. note:: + When using one of the linear interpolation modes the ``align_corners`` + argument changes how the corners are treated in the input image. If + ``align_corners=True`` then the top and left edge of the input and + output will be matching as will the bottom right edge. + + Parameters: + scale_factor (float or tuple): The multiplier for the spatial size. + If a ``float`` is provided, it is the multiplier for all spatial dimensions. + Otherwise, the number of scale factors provided must match the + number of spatial dimensions. + mode (str, optional): The upsampling algorithm, either ``"nearest"`` or + ``"linear"``. Default: ``"nearest"``. + align_corners (bool, optional): Changes the way the corners are treated + during ``"linear"`` upsampling. See the note above and the + examples below for more details. Default: ``False``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn as nn + >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1)) + >>> x + array([[[[1], + [2]], + [[3], + [4]]]], dtype=int32) + >>> n = nn.Upsample(scale_factor=2, mode='nearest') + >>> n(x).squeeze() + array([[1, 1, 2, 2], + [1, 1, 2, 2], + [3, 3, 4, 4], + [3, 3, 4, 4]], dtype=int32) + >>> b = nn.Upsample(scale_factor=2, mode='linear') + >>> b(x).squeeze() + array([[1, 1.25, 1.75, 2], + [1.5, 1.75, 2.25, 2.5], + [2.5, 2.75, 3.25, 3.5], + [3, 3.25, 3.75, 4]], dtype=float32) + >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True) + >>> b(x).squeeze() + array([[1, 1.33333, 1.66667, 2], + [1.66667, 2, 2.33333, 2.66667], + [2.33333, 2.66667, 3, 3.33333], + [3, 3.33333, 3.66667, 4]], dtype=float32) + """ + + def __init__( + self, + scale_factor: Union[float, Tuple], + mode: Literal["nearest", "linear"] = "nearest", + align_corners: bool = False, + ): + super().__init__() + if mode not in ["nearest", "linear"]: + raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}") + if isinstance(scale_factor, (list, tuple)): + self.scale_factor = tuple(map(float, scale_factor)) + else: + self.scale_factor = float(scale_factor) + self.mode = mode + self.align_corners = align_corners + + def _extra_repr(self) -> str: + return ( + f"scale_factor={self.scale_factor}, mode={self.mode!r}, " + f"align_corners={self.align_corners}" + ) + + def __call__(self, x: mx.array) -> mx.array: + dims = x.ndim - 2 + if dims <= 0: + raise ValueError( + f"[Upsample] The input should have at least 1 spatial " + f"dimension which means it should be at least 3D but " + f"{x.ndim}D was provided" + ) + + scale_factor = self.scale_factor + if isinstance(scale_factor, tuple): + if len(scale_factor) != dims: + raise ValueError( + f"[Upsample] One scale per spatial dimension is required but " + f"scale_factor={scale_factor} and the number of spatial " + f"dimensions were {dims}" + ) + else: + scale_factor = (scale_factor,) * dims + + if self.mode == "nearest": + return upsample_nearest(x, scale_factor) + + else: + return upsample_linear(x, scale_factor, self.align_corners) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index eaaf3bb9c..2c8346179 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import os import tempfile @@ -8,7 +8,7 @@ import mlx.core as mx import mlx.nn as nn import mlx_tests import numpy as np -from mlx.utils import tree_flatten, tree_map, tree_unflatten +from mlx.utils import tree_flatten, tree_map class TestBase(mlx_tests.MLXTestCase): @@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase): self.assertTrue(y.shape, x.shape) self.assertTrue(y.dtype, mx.float16) + def test_upsample(self): + b, h, w, c = 1, 2, 2, 1 + scale_factor = 2 + upsample_nearest = nn.Upsample( + scale_factor=scale_factor, mode="nearest", align_corners=True + ) + upsample_bilinear = nn.Upsample( + scale_factor=scale_factor, mode="linear", align_corners=True + ) + upsample_nearest = nn.Upsample( + scale_factor=scale_factor, mode="nearest", align_corners=True + ) + upsample_bilinear_no_align_corners = nn.Upsample( + scale_factor=scale_factor, mode="linear", align_corners=False + ) + upsample_nearest_no_align_corners = nn.Upsample( + scale_factor=scale_factor, mode="nearest", align_corners=False + ) + # Test single feature map, align corners + x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1)) + expected_nearest = mx.array( + [[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]] + ).transpose((0, 2, 3, 1)) + expected_bilinear = mx.array( + [ + [ + [ + [0, 0.333333, 0.666667, 1], + [0.666667, 1, 1.33333, 1.66667], + [1.33333, 1.66667, 2, 2.33333], + [2, 2.33333, 2.66667, 3], + ] + ] + ] + ).transpose((0, 2, 3, 1)) + # Test single feature map, no align corners + x = ( + mx.arange(1, b * h * w * c + 1) + .reshape((b, c, h, w)) + .transpose((0, 2, 3, 1)) + ) + expected_bilinear_no_align_corners = mx.array( + [ + [ + [ + [1.0000, 1.2500, 1.7500, 2.0000], + [1.5000, 1.7500, 2.2500, 2.5000], + [2.5000, 2.7500, 3.2500, 3.5000], + [3.0000, 3.2500, 3.7500, 4.0000], + ] + ] + ] + ).transpose((0, 2, 3, 1)) + expected_nearest_no_align_corners = mx.array( + [[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]] + ).transpose((0, 2, 3, 1)) + self.assertTrue( + np.allclose( + upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners + ) + ) + self.assertTrue( + np.allclose( + upsample_bilinear_no_align_corners(x), + expected_bilinear_no_align_corners, + ) + ) + + # Test a more complex batch + b, h, w, c = 2, 3, 3, 2 + scale_factor = 2 + x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1)) + + upsample_nearest = nn.Upsample( + scale_factor=scale_factor, mode="nearest", align_corners=True + ) + upsample_bilinear = nn.Upsample( + scale_factor=scale_factor, mode="linear", align_corners=True + ) + + expected_nearest = mx.array( + [ + [ + [ + [0.0, 0.0, 1.0, 1.0, 2.0, 2.0], + [0.0, 0.0, 1.0, 1.0, 2.0, 2.0], + [3.0, 3.0, 4.0, 4.0, 5.0, 5.0], + [3.0, 3.0, 4.0, 4.0, 5.0, 5.0], + [6.0, 6.0, 7.0, 7.0, 8.0, 8.0], + [6.0, 6.0, 7.0, 7.0, 8.0, 8.0], + ], + [ + [9.0, 9.0, 10.0, 10.0, 11.0, 11.0], + [9.0, 9.0, 10.0, 10.0, 11.0, 11.0], + [12.0, 12.0, 13.0, 13.0, 14.0, 14.0], + [12.0, 12.0, 13.0, 13.0, 14.0, 14.0], + [15.0, 15.0, 16.0, 16.0, 17.0, 17.0], + [15.0, 15.0, 16.0, 16.0, 17.0, 17.0], + ], + ], + [ + [ + [18.0, 18.0, 19.0, 19.0, 20.0, 20.0], + [18.0, 18.0, 19.0, 19.0, 20.0, 20.0], + [21.0, 21.0, 22.0, 22.0, 23.0, 23.0], + [21.0, 21.0, 22.0, 22.0, 23.0, 23.0], + [24.0, 24.0, 25.0, 25.0, 26.0, 26.0], + [24.0, 24.0, 25.0, 25.0, 26.0, 26.0], + ], + [ + [27.0, 27.0, 28.0, 28.0, 29.0, 29.0], + [27.0, 27.0, 28.0, 28.0, 29.0, 29.0], + [30.0, 30.0, 31.0, 31.0, 32.0, 32.0], + [30.0, 30.0, 31.0, 31.0, 32.0, 32.0], + [33.0, 33.0, 34.0, 34.0, 35.0, 35.0], + [33.0, 33.0, 34.0, 34.0, 35.0, 35.0], + ], + ], + ] + ).transpose((0, 2, 3, 1)) + expected_bilinear = mx.array( + [ + [ + [ + [0.0, 0.4, 0.8, 1.2, 1.6, 2.0], + [1.2, 1.6, 2.0, 2.4, 2.8, 3.2], + [2.4, 2.8, 3.2, 3.6, 4.0, 4.4], + [3.6, 4.0, 4.4, 4.8, 5.2, 5.6], + [4.8, 5.2, 5.6, 6.0, 6.4, 6.8], + [6.0, 6.4, 6.8, 7.2, 7.6, 8.0], + ], + [ + [9.0, 9.4, 9.8, 10.2, 10.6, 11.0], + [10.2, 10.6, 11.0, 11.4, 11.8, 12.2], + [11.4, 11.8, 12.2, 12.6, 13.0, 13.4], + [12.6, 13.0, 13.4, 13.8, 14.2, 14.6], + [13.8, 14.2, 14.6, 15.0, 15.4, 15.8], + [15.0, 15.4, 15.8, 16.2, 16.6, 17.0], + ], + ], + [ + [ + [18.0, 18.4, 18.8, 19.2, 19.6, 20.0], + [19.2, 19.6, 20.0, 20.4, 20.8, 21.2], + [20.4, 20.8, 21.2, 21.6, 22.0, 22.4], + [21.6, 22.0, 22.4, 22.8, 23.2, 23.6], + [22.8, 23.2, 23.6, 24.0, 24.4, 24.8], + [24.0, 24.4, 24.8, 25.2, 25.6, 26.0], + ], + [ + [27.0, 27.4, 27.8, 28.2, 28.6, 29.0], + [28.2, 28.6, 29.0, 29.4, 29.8, 30.2], + [29.4, 29.8, 30.2, 30.6, 31.0, 31.4], + [30.6, 31.0, 31.4, 31.8, 32.2, 32.6], + [31.8, 32.2, 32.6, 33.0, 33.4, 33.8], + [33.0, 33.4, 33.8, 34.2, 34.6, 35.0], + ], + ], + ] + ).transpose((0, 2, 3, 1)) + self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest)) + self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear)) + + # Test different height and width scale_factor + b, h, w, c = 1, 2, 2, 2 + x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1)) + upsample_nearest = nn.Upsample( + scale_factor=(2, 3), mode="nearest", align_corners=True + ) + upsample_bilinear = nn.Upsample( + scale_factor=(2, 3), mode="linear", align_corners=True + ) + + expected_nearest = mx.array( + [ + [ + [ + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [2, 2, 2, 3, 3, 3], + [2, 2, 2, 3, 3, 3], + ], + [ + [4, 4, 4, 5, 5, 5], + [4, 4, 4, 5, 5, 5], + [6, 6, 6, 7, 7, 7], + [6, 6, 6, 7, 7, 7], + ], + ] + ] + ).transpose((0, 2, 3, 1)) + expected_bilinear = mx.array( + [ + [ + [ + [0, 0.2, 0.4, 0.6, 0.8, 1], + [0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667], + [1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333], + [2, 2.2, 2.4, 2.6, 2.8, 3], + ], + [ + [4, 4.2, 4.4, 4.6, 4.8, 5], + [4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667], + [5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333], + [6, 6.2, 6.4, 6.6, 6.8, 7], + ], + ] + ] + ).transpose((0, 2, 3, 1)) + self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest)) + self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear)) + + # Test repr + self.assertEqual( + str(nn.Upsample(scale_factor=2)), + "Upsample(scale_factor=2.0, mode='nearest', align_corners=False)", + ) + self.assertEqual( + str(nn.Upsample(scale_factor=(2, 3))), + "Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)", + ) + def test_pooling(self): # Test 1d pooling x = mx.array(