Upsample2d (#414)

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Gabrijel Boduljak
2024-02-23 18:55:04 +01:00
committed by GitHub
parent d729a1991b
commit 22364c40b7
5 changed files with 433 additions and 4 deletions

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import os
import tempfile
@@ -8,7 +8,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx_tests
import numpy as np
from mlx.utils import tree_flatten, tree_map, tree_unflatten
from mlx.utils import tree_flatten, tree_map
class TestBase(mlx_tests.MLXTestCase):
@@ -905,6 +905,228 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertTrue(y.shape, x.shape)
self.assertTrue(y.dtype, mx.float16)
def test_upsample(self):
b, h, w, c = 1, 2, 2, 1
scale_factor = 2
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=True
)
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear_no_align_corners = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=False
)
upsample_nearest_no_align_corners = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=False
)
# Test single feature map, align corners
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
expected_nearest = mx.array(
[[[[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]]]]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.333333, 0.666667, 1],
[0.666667, 1, 1.33333, 1.66667],
[1.33333, 1.66667, 2, 2.33333],
[2, 2.33333, 2.66667, 3],
]
]
]
).transpose((0, 2, 3, 1))
# Test single feature map, no align corners
x = (
mx.arange(1, b * h * w * c + 1)
.reshape((b, c, h, w))
.transpose((0, 2, 3, 1))
)
expected_bilinear_no_align_corners = mx.array(
[
[
[
[1.0000, 1.2500, 1.7500, 2.0000],
[1.5000, 1.7500, 2.2500, 2.5000],
[2.5000, 2.7500, 3.2500, 3.5000],
[3.0000, 3.2500, 3.7500, 4.0000],
]
]
]
).transpose((0, 2, 3, 1))
expected_nearest_no_align_corners = mx.array(
[[[[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]]]
).transpose((0, 2, 3, 1))
self.assertTrue(
np.allclose(
upsample_nearest_no_align_corners(x), expected_nearest_no_align_corners
)
)
self.assertTrue(
np.allclose(
upsample_bilinear_no_align_corners(x),
expected_bilinear_no_align_corners,
)
)
# Test a more complex batch
b, h, w, c = 2, 3, 3, 2
scale_factor = 2
x = mx.arange((b * h * w * c)).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
upsample_nearest = nn.Upsample(
scale_factor=scale_factor, mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=scale_factor, mode="linear", align_corners=True
)
expected_nearest = mx.array(
[
[
[
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[3.0, 3.0, 4.0, 4.0, 5.0, 5.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
[6.0, 6.0, 7.0, 7.0, 8.0, 8.0],
],
[
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[9.0, 9.0, 10.0, 10.0, 11.0, 11.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[12.0, 12.0, 13.0, 13.0, 14.0, 14.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
[15.0, 15.0, 16.0, 16.0, 17.0, 17.0],
],
],
[
[
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[18.0, 18.0, 19.0, 19.0, 20.0, 20.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[21.0, 21.0, 22.0, 22.0, 23.0, 23.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
[24.0, 24.0, 25.0, 25.0, 26.0, 26.0],
],
[
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[27.0, 27.0, 28.0, 28.0, 29.0, 29.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[30.0, 30.0, 31.0, 31.0, 32.0, 32.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
[33.0, 33.0, 34.0, 34.0, 35.0, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0.0, 0.4, 0.8, 1.2, 1.6, 2.0],
[1.2, 1.6, 2.0, 2.4, 2.8, 3.2],
[2.4, 2.8, 3.2, 3.6, 4.0, 4.4],
[3.6, 4.0, 4.4, 4.8, 5.2, 5.6],
[4.8, 5.2, 5.6, 6.0, 6.4, 6.8],
[6.0, 6.4, 6.8, 7.2, 7.6, 8.0],
],
[
[9.0, 9.4, 9.8, 10.2, 10.6, 11.0],
[10.2, 10.6, 11.0, 11.4, 11.8, 12.2],
[11.4, 11.8, 12.2, 12.6, 13.0, 13.4],
[12.6, 13.0, 13.4, 13.8, 14.2, 14.6],
[13.8, 14.2, 14.6, 15.0, 15.4, 15.8],
[15.0, 15.4, 15.8, 16.2, 16.6, 17.0],
],
],
[
[
[18.0, 18.4, 18.8, 19.2, 19.6, 20.0],
[19.2, 19.6, 20.0, 20.4, 20.8, 21.2],
[20.4, 20.8, 21.2, 21.6, 22.0, 22.4],
[21.6, 22.0, 22.4, 22.8, 23.2, 23.6],
[22.8, 23.2, 23.6, 24.0, 24.4, 24.8],
[24.0, 24.4, 24.8, 25.2, 25.6, 26.0],
],
[
[27.0, 27.4, 27.8, 28.2, 28.6, 29.0],
[28.2, 28.6, 29.0, 29.4, 29.8, 30.2],
[29.4, 29.8, 30.2, 30.6, 31.0, 31.4],
[30.6, 31.0, 31.4, 31.8, 32.2, 32.6],
[31.8, 32.2, 32.6, 33.0, 33.4, 33.8],
[33.0, 33.4, 33.8, 34.2, 34.6, 35.0],
],
],
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
# Test different height and width scale_factor
b, h, w, c = 1, 2, 2, 2
x = mx.arange(b * h * w * c).reshape((b, c, h, w)).transpose((0, 2, 3, 1))
upsample_nearest = nn.Upsample(
scale_factor=(2, 3), mode="nearest", align_corners=True
)
upsample_bilinear = nn.Upsample(
scale_factor=(2, 3), mode="linear", align_corners=True
)
expected_nearest = mx.array(
[
[
[
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[2, 2, 2, 3, 3, 3],
[2, 2, 2, 3, 3, 3],
],
[
[4, 4, 4, 5, 5, 5],
[4, 4, 4, 5, 5, 5],
[6, 6, 6, 7, 7, 7],
[6, 6, 6, 7, 7, 7],
],
]
]
).transpose((0, 2, 3, 1))
expected_bilinear = mx.array(
[
[
[
[0, 0.2, 0.4, 0.6, 0.8, 1],
[0.666667, 0.866667, 1.06667, 1.26667, 1.46667, 1.66667],
[1.33333, 1.53333, 1.73333, 1.93333, 2.13333, 2.33333],
[2, 2.2, 2.4, 2.6, 2.8, 3],
],
[
[4, 4.2, 4.4, 4.6, 4.8, 5],
[4.66667, 4.86667, 5.06667, 5.26667, 5.46667, 5.66667],
[5.33333, 5.53333, 5.73333, 5.93333, 6.13333, 6.33333],
[6, 6.2, 6.4, 6.6, 6.8, 7],
],
]
]
).transpose((0, 2, 3, 1))
self.assertTrue(np.allclose(upsample_nearest(x), expected_nearest))
self.assertTrue(np.allclose(upsample_bilinear(x), expected_bilinear))
# Test repr
self.assertEqual(
str(nn.Upsample(scale_factor=2)),
"Upsample(scale_factor=2.0, mode='nearest', align_corners=False)",
)
self.assertEqual(
str(nn.Upsample(scale_factor=(2, 3))),
"Upsample(scale_factor=(2.0, 3.0), mode='nearest', align_corners=False)",
)
def test_pooling(self):
# Test 1d pooling
x = mx.array(