Nearest upsample (#2202)

This commit is contained in:
Angelos Katharopoulos
2025-05-19 11:23:38 -07:00
committed by GitHub
parent 237f9e58a8
commit 0359bf02c9
2 changed files with 16 additions and 6 deletions

View File

@@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
def _nearest_indices(N, scale, dim, ndims):
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
M = int(scale * N)
indices = mx.arange(M, dtype=mx.float32)
if M > N:
indices = (indices + 0.5) * (N / M) - 0.5
indices = indices.round()
else:
indices = indices * (N / M)
shape = [1] * ndims
shape[dim] = -1
return indices.astype(mx.uint32).reshape(shape)
def _linear_indices(N, scale, align_corners, dim, ndims):