More fixes for nearest

This commit is contained in:
Angelos Katharopoulos
2025-05-18 23:41:37 -07:00
parent f1d29276b0
commit cdd3e9c573
2 changed files with 17 additions and 4 deletions

View File

@@ -51,6 +51,7 @@ class TestUpsample(mlx_tests.MLXTestCase):
align_corners=align_corner,
)(in_mx)
mode_pt = {
"nearest": "nearest",
"linear": "bilinear",
"cubic": "bicubic",
}[mode]
@@ -58,7 +59,7 @@ class TestUpsample(mlx_tests.MLXTestCase):
in_pt,
scale_factor=scale_factor,
mode=mode_pt,
align_corners=align_corner,
align_corners=align_corner if mode != "nearest" else None,
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
self.assertEqual(out_pt.shape, out_mx.shape)
@@ -76,14 +77,17 @@ class TestUpsample(mlx_tests.MLXTestCase):
((4, 4), (0.5, 0.5)),
((7, 7), (2.0, 2.0)),
((10, 10), (0.2, 0.2)),
((10, 10), (0.3, 0.3)),
((11, 21), (3.0, 3.0)),
((11, 21), (3.0, 2.0)),
):
# only test linear and cubic interpolation
# there will be numerical difference in nearest
# due to different indices selection.
for mode in ("cubic", "linear"):
for mode in ("cubic", "linear", "nearest"):
for align_corner in (False, True):
if mode == "nearest" and align_corner:
continue
run_upsample(
N,
C,