Add "edge" mode to mx.pad (#1309)

* Add edge padding mode

* fix pad in pooling

* string arg instead of enum
This commit is contained in:
Alex Barron
2024-08-06 11:23:10 -07:00
committed by GitHub
parent 8c9f0278b9
commit 635ccd9e25
6 changed files with 102 additions and 18 deletions

View File

@@ -101,7 +101,11 @@ class _Pool(Module):
def __call__(self, x):
if any(p[0] > 0 for p in self._padding):
x = mx.pad(x, [(0, 0)] + self._padding + [(0, 0)], self._padding_value)
x = mx.pad(
x,
[(0, 0)] + self._padding + [(0, 0)],
constant_values=self._padding_value,
)
x = _sliding_windows(x, self._kernel_size, self._stride)
return self._pooling_function(x, self._axes)

View File

@@ -2843,30 +2843,32 @@ void init_ops(nb::module_& m) {
std::tuple<int>,
std::pair<int, int>,
std::vector<std::pair<int, int>>>& pad_width,
const std::string mode,
const ScalarOrArray& constant_value,
StreamOrDevice s) {
if (auto pv = std::get_if<int>(&pad_width); pv) {
return pad(a, *pv, to_array(constant_value), s);
return pad(a, *pv, to_array(constant_value), mode, s);
} else if (auto pv = std::get_if<std::tuple<int>>(&pad_width); pv) {
return pad(a, std::get<0>(*pv), to_array(constant_value), s);
return pad(a, std::get<0>(*pv), to_array(constant_value), mode, s);
} else if (auto pv = std::get_if<std::pair<int, int>>(&pad_width); pv) {
return pad(a, *pv, to_array(constant_value), s);
return pad(a, *pv, to_array(constant_value), mode, s);
} else {
auto v = std::get<std::vector<std::pair<int, int>>>(pad_width);
if (v.size() == 1) {
return pad(a, v[0], to_array(constant_value), s);
return pad(a, v[0], to_array(constant_value), mode, s);
} else {
return pad(a, v, to_array(constant_value), s);
return pad(a, v, to_array(constant_value), mode, s);
}
}
},
nb::arg(),
"pad_width"_a,
"mode"_a = "constant",
"constant_values"_a = 0,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
"def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Pad an array with a constant value
@@ -2878,6 +2880,9 @@ void init_ops(nb::module_& m) {
of integers is passed then ``(before_i, after_i)`` are all the same.
If a single integer or tuple with a single integer is passed then
all axes are extended by the same number on each side.
mode: Padding mode. One of the following strings:
"constant" (default): Pads with a constant value.
"edge": Pads with the edge values of array.
constant_value (array or scalar, optional): Optional constant value
to pad the edges of the array with.
@@ -3155,7 +3160,8 @@ void init_ops(nb::module_& m) {
} else { // Even sizes use asymmetric padding
int pad_l = wt.size() / 2;
int pad_r = std::max(0, pad_l - 1);
in = pad(in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), s);
in = pad(
in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), "constant", s);
}
} else {

View File

@@ -1623,6 +1623,12 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
b_npy = np.pad(a_npy, pw, mode="edge")
b_mlx = mx.pad(a_mlx, pw, mode="edge")
self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
a = mx.zeros((1, 1, 1))
self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))