mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
Nearest upsample (#2202)
This commit is contained in:

committed by
GitHub

parent
237f9e58a8
commit
0359bf02c9
@@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
|
||||
|
||||
|
||||
def _nearest_indices(N, scale, dim, ndims):
|
||||
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
|
||||
M = int(scale * N)
|
||||
indices = mx.arange(M, dtype=mx.float32)
|
||||
if M > N:
|
||||
indices = (indices + 0.5) * (N / M) - 0.5
|
||||
indices = indices.round()
|
||||
else:
|
||||
indices = indices * (N / M)
|
||||
shape = [1] * ndims
|
||||
shape[dim] = -1
|
||||
return indices.astype(mx.uint32).reshape(shape)
|
||||
|
||||
|
||||
def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||
|
Reference in New Issue
Block a user