From f1d29276b0cdebde8a74d8bc19941c08eccf446a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 18 May 2025 19:43:58 -0700 Subject: [PATCH] Fix --- python/mlx/nn/layers/upsample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py index 6260a4add..ea61ac8fc 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -12,7 +12,7 @@ from mlx.nn.layers.base import Module def _scaled_indices(N, scale, align_corners, dim, ndims): M = int(scale * N) if align_corners: - indices = ((mx.arange(M, dtype=mx.float32) + 0.5) * (N / M) - 0.5).round() + indices = (mx.arange(M, dtype=mx.float32) + 0.5) * (N / M) - 0.5 else: step = 1 / scale start = ((M - 1) * step - N + 1) / 2 @@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims): def _nearest_indices(N, scale, dim, ndims): - return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32) + return _scaled_indices(N, scale, True, dim, ndims).round().astype(mx.uint32) def _linear_indices(N, scale, align_corners, dim, ndims):