Match PT for nearest upsample

This commit is contained in:
Angelos Katharopoulos 2025-05-18 19:33:47 -07:00
parent 8576e6fe36
commit 3cdc4a5f70

View File

@ -12,7 +12,7 @@ from mlx.nn.layers.base import Module
def _scaled_indices(N, scale, align_corners, dim, ndims):
M = int(scale * N)
if align_corners:
indices = mx.arange(M, dtype=mx.float32) * ((N - 1) / (M - 1))
indices = ((mx.arange(M, dtype=mx.float32) + 0.5) * (N / M) - 0.5).round()
else:
step = 1 / scale
start = ((M - 1) * step - N + 1) / 2