mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Nearest upsample (#2202)
This commit is contained in:
parent
237f9e58a8
commit
0359bf02c9
@ -25,7 +25,16 @@ def _scaled_indices(N, scale, align_corners, dim, ndims):
|
|||||||
|
|
||||||
|
|
||||||
def _nearest_indices(N, scale, dim, ndims):
|
def _nearest_indices(N, scale, dim, ndims):
|
||||||
return _scaled_indices(N, scale, True, dim, ndims).astype(mx.uint32)
|
M = int(scale * N)
|
||||||
|
indices = mx.arange(M, dtype=mx.float32)
|
||||||
|
if M > N:
|
||||||
|
indices = (indices + 0.5) * (N / M) - 0.5
|
||||||
|
indices = indices.round()
|
||||||
|
else:
|
||||||
|
indices = indices * (N / M)
|
||||||
|
shape = [1] * ndims
|
||||||
|
shape[dim] = -1
|
||||||
|
return indices.astype(mx.uint32).reshape(shape)
|
||||||
|
|
||||||
|
|
||||||
def _linear_indices(N, scale, align_corners, dim, ndims):
|
def _linear_indices(N, scale, align_corners, dim, ndims):
|
||||||
|
@ -51,6 +51,7 @@ class TestUpsample(mlx_tests.MLXTestCase):
|
|||||||
align_corners=align_corner,
|
align_corners=align_corner,
|
||||||
)(in_mx)
|
)(in_mx)
|
||||||
mode_pt = {
|
mode_pt = {
|
||||||
|
"nearest": "nearest",
|
||||||
"linear": "bilinear",
|
"linear": "bilinear",
|
||||||
"cubic": "bicubic",
|
"cubic": "bicubic",
|
||||||
}[mode]
|
}[mode]
|
||||||
@ -58,7 +59,7 @@ class TestUpsample(mlx_tests.MLXTestCase):
|
|||||||
in_pt,
|
in_pt,
|
||||||
scale_factor=scale_factor,
|
scale_factor=scale_factor,
|
||||||
mode=mode_pt,
|
mode=mode_pt,
|
||||||
align_corners=align_corner,
|
align_corners=align_corner if mode != "nearest" else None,
|
||||||
)
|
)
|
||||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
|
||||||
self.assertEqual(out_pt.shape, out_mx.shape)
|
self.assertEqual(out_pt.shape, out_mx.shape)
|
||||||
@ -76,14 +77,14 @@ class TestUpsample(mlx_tests.MLXTestCase):
|
|||||||
((4, 4), (0.5, 0.5)),
|
((4, 4), (0.5, 0.5)),
|
||||||
((7, 7), (2.0, 2.0)),
|
((7, 7), (2.0, 2.0)),
|
||||||
((10, 10), (0.2, 0.2)),
|
((10, 10), (0.2, 0.2)),
|
||||||
|
((10, 10), (0.3, 0.3)),
|
||||||
((11, 21), (3.0, 3.0)),
|
((11, 21), (3.0, 3.0)),
|
||||||
((11, 21), (3.0, 2.0)),
|
((11, 21), (3.0, 2.0)),
|
||||||
):
|
):
|
||||||
# only test linear and cubic interpolation
|
for mode in ("cubic", "linear", "nearest"):
|
||||||
# there will be numerical difference in nearest
|
|
||||||
# due to different indices selection.
|
|
||||||
for mode in ("cubic", "linear"):
|
|
||||||
for align_corner in (False, True):
|
for align_corner in (False, True):
|
||||||
|
if mode == "nearest" and align_corner:
|
||||||
|
continue
|
||||||
run_upsample(
|
run_upsample(
|
||||||
N,
|
N,
|
||||||
C,
|
C,
|
||||||
|
Loading…
Reference in New Issue
Block a user