Add "edge" mode to mx.pad (#1309)

* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
This commit is contained in:
Alex Barron
2024-08-06 11:23:10 -07:00
committed by GitHub
parent 8c9f0278b9
commit 635ccd9e25
6 changed files with 102 additions and 18 deletions

View File

@@ -101,7 +101,11 @@ class _Pool(Module):
def __call__(self, x):
if any(p[0] > 0 for p in self._padding):
x = mx.pad(x, [(0, 0)] + self._padding + [(0, 0)], self._padding_value)
x = mx.pad(
x,
[(0, 0)] + self._padding + [(0, 0)],
constant_values=self._padding_value,
)
x = _sliding_windows(x, self._kernel_size, self._stride)
return self._pooling_function(x, self._axes)