Add "edge" mode to mx.pad (#1309)

* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
This commit is contained in:
Alex Barron
2024-08-06 11:23:10 -07:00
committed by GitHub
parent 8c9f0278b9
commit 635ccd9e25
6 changed files with 102 additions and 18 deletions

View File

@@ -1623,6 +1623,12 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
b_npy = np.pad(a_npy, pw, mode="edge")
b_mlx = mx.pad(a_mlx, pw, mode="edge")
self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
a = mx.zeros((1, 1, 1))
self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))