diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py index ea61ac8fc..e6bd282af 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -12,7 +12,7 @@ from mlx.nn.layers.base import Module def _scaled_indices(N, scale, align_corners, dim, ndims): M = int(scale * N) if align_corners: - indices = (mx.arange(M, dtype=mx.float32) + 0.5) * (N / M) - 0.5 + indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1)) else: step = 1 / scale start = ((M - 1) * step - N + 1) / 2 @@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims): def _nearest_indices(N, scale, dim, ndims): - return _scaled_indices(N, scale, True, dim, ndims).round().astype(mx.uint32) + M = int(scale * N) + indices = mx.arange(M, dtype=mx.float32) + if M > N: + indices = (indices + 0.5) * (N / M) - 0.5 + indices = indices.round() + else: + indices = indices * (N / M) + shape = [1] * ndims + shape[dim] = -1 + return indices.astype(mx.uint32).reshape(shape) def _linear_indices(N, scale, align_corners, dim, ndims): diff --git a/python/tests/test_upsample.py b/python/tests/test_upsample.py index 402c7b0ca..b5314ae70 100644 --- a/python/tests/test_upsample.py +++ b/python/tests/test_upsample.py @@ -51,6 +51,7 @@ class TestUpsample(mlx_tests.MLXTestCase): align_corners=align_corner, )(in_mx) mode_pt = { + "nearest": "nearest", "linear": "bilinear", "cubic": "bicubic", }[mode] @@ -58,7 +59,7 @@ class TestUpsample(mlx_tests.MLXTestCase): in_pt, scale_factor=scale_factor, mode=mode_pt, - align_corners=align_corner, + align_corners=align_corner if mode != "nearest" else None, ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) self.assertEqual(out_pt.shape, out_mx.shape) @@ -76,14 +77,17 @@ class TestUpsample(mlx_tests.MLXTestCase): ((4, 4), (0.5, 0.5)), ((7, 7), (2.0, 2.0)), ((10, 10), (0.2, 0.2)), + ((10, 10), (0.3, 0.3)), ((11, 21), (3.0, 3.0)), ((11, 21), (3.0, 2.0)), ): # only test linear and cubic interpolation # there will be numerical difference in nearest # due to different indices selection. - for mode in ("cubic", "linear"): + for mode in ("cubic", "linear", "nearest"): for align_corner in (False, True): + if mode == "nearest" and align_corner: + continue run_upsample( N, C,