mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-14 12:31:13 +08:00
Fix
This commit is contained in:
parent
3cdc4a5f70
commit
f1d29276b0
@ -12,7 +12,7 @@ from mlx.nn.layers.base import Module
|
|||||||
def _scaled_indices(N, scale, align_corners, dim, ndims):
|
def _scaled_indices(N, scale, align_corners, dim, ndims):
|
||||||
M = int(scale * N)
|
M = int(scale * N)
|
||||||
if align_corners:
|
if align_corners:
|
||||||
indices = ((mx.arange(M, dtype=mx.float32) + 0.5) * (N / M) - 0.5).round()
|
indices = (mx.arange(M, dtype=mx.float32) + 0.5) * (N / M) - 0.5
|
||||||
else:
|
else:
|
||||||
step = 1 / scale
|
step = 1 / scale
|
||||||
start = ((M - 1) * step - N + 1) / 2
|
start = ((M - 1) * step - N + 1) / 2
|
||||||
@ -25,7 +25,7 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
|
|||||||
|
|
||||||
|
|
||||||
def _nearest_indices(N, scale, dim, ndims):
|
def _nearest_indices(N, scale, dim, ndims):
|
||||||
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
|
return _scaled_indices(N, scale, True, dim, ndims).round().astype(mx.uint32)
|
||||||
|
|
||||||
|
|
||||||
def _linear_indices(N, scale, align_corners, dim, ndims):
|
def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||||
|
Loading…
Reference in New Issue
Block a user