diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py index 1f2ffd3da..6260a4add 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -12,7 +12,7 @@ from mlx.nn.layers.base import Module def _scaled_indices(N, scale, align_corners, dim, ndims): M = int(scale * N) if align_corners: - indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1)) + indices = ((mx.arange(M, dtype=mx.float32) + 0.5) * (N / M) - 0.5).round() else: step = 1 / scale start = ((M - 1) * step - N + 1) / 2