Nearest upsample (#2202)

This commit is contained in:
Angelos Katharopoulos 2025-05-19 11:23:38 -07:00 committed by GitHub
parent 237f9e58a8
commit 0359bf02c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 6 deletions

View File

@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
def _nearest_indices(N, scale, dim, ndims): def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32) M = int(scale * N)
indices = mx.arange(M, dtype=mx.float32)
if M > N:
indices = (indices + 0.5) * (N / M) - 0.5
indices = indices.round()
else:
indices = indices * (N / M)
shape = [1] * ndims
shape[dim] = -1
return indices.astype(mx.uint32).reshape(shape)
def _linear_indices(N, scale, align_corners, dim, ndims): def _linear_indices(N, scale, align_corners, dim, ndims):

View File

@ -51,6 +51,7 @@ class TestUpsample(mlx_tests.MLXTestCase):
align_corners=align_corner, align_corners=align_corner,
)(in_mx) )(in_mx)
mode_pt = { mode_pt = {
"nearest": "nearest",
"linear": "bilinear", "linear": "bilinear",
"cubic": "bicubic", "cubic": "bicubic",
}[mode] }[mode]
@ -58,7 +59,7 @@ class TestUpsample(mlx_tests.MLXTestCase):
in_pt, in_pt,
scale_factor=scale_factor, scale_factor=scale_factor,
mode=mode_pt, mode=mode_pt,
align_corners=align_corner, align_corners=align_corner if mode != "nearest" else None,
) )
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
self.assertEqual(out_pt.shape, out_mx.shape) self.assertEqual(out_pt.shape, out_mx.shape)
@ -76,14 +77,14 @@ class TestUpsample(mlx_tests.MLXTestCase):
((4, 4), (0.5, 0.5)), ((4, 4), (0.5, 0.5)),
((7, 7), (2.0, 2.0)), ((7, 7), (2.0, 2.0)),
((10, 10), (0.2, 0.2)), ((10, 10), (0.2, 0.2)),
((10, 10), (0.3, 0.3)),
((11, 21), (3.0, 3.0)), ((11, 21), (3.0, 3.0)),
((11, 21), (3.0, 2.0)), ((11, 21), (3.0, 2.0)),
): ):
# only test linear and cubic interpolation for mode in ("cubic", "linear", "nearest"):
# there will be numerical difference in nearest
# due to different indices selection.
for mode in ("cubic", "linear"):
for align_corner in (False, True): for align_corner in (False, True):
if mode == "nearest" and align_corner:
continue
run_upsample( run_upsample(
N, N,
C, C,