mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 10:41:14 +08:00
Add "edge" mode to mx.pad (#1309)
* Add edge padding mode * fix pad in pooling * string arg instead of enum
This commit is contained in:
parent
8c9f0278b9
commit
635ccd9e25
76
mlx/ops.cpp
76
mlx/ops.cpp
@ -1000,6 +1000,51 @@ array tile(
|
|||||||
return reshape(x, final_shape, s);
|
return reshape(x, final_shape, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Pad an array by replicating its edge values ("edge" padding mode).
 *
 * @param a              Input array.
 * @param axes           Axes to pad; negative values count from the last
 *                       dimension.
 * @param low_pad_size   Elements to add before each axis listed in `axes`
 *                       (entry `i` applies to `axes[i]`).
 * @param high_pad_size  Elements to add after each axis listed in `axes`
 *                       (entry `i` applies to `axes[i]`).
 * @param out_shape      Shape of the padded result, i.e. the shape of `a`
 *                       grown by the low and high padding on every padded
 *                       axis.
 * @param s              Stream or device on which the ops are scheduled.
 * @return The padded array, of shape `out_shape` and dtype `a.dtype()`.
 * @throws std::invalid_argument if an entry of `axes` is out of range for `a`.
 *
 * The caller is expected to pass `axes`, `low_pad_size` and `high_pad_size`
 * with equal lengths.
 */
array edge_pad(
    const array& a,
    const std::vector<int>& axes,
    const std::vector<int>& low_pad_size,
    const std::vector<int>& high_pad_size,
    const std::vector<int>& out_shape,
    StreamOrDevice s /* = {}*/) {
  const int ndim = static_cast<int>(a.ndim());

  // The pad sizes are given per entry of `axes`, not per dimension of `a`.
  // Expand them to one value per dimension so that the slicing below always
  // addresses the right axis, including when only a subset of the axes is
  // padded or when the axes are negative / unordered.
  std::vector<int> lows(ndim, 0);
  std::vector<int> highs(ndim, 0);
  for (size_t i = 0; i < axes.size(); ++i) {
    const int ax = axes[i] < 0 ? axes[i] + ndim : axes[i];
    if (ax < 0 || ax >= ndim) {
      std::ostringstream msg;
      msg << "[pad] Invalid axis " << axes[i] << " for array with " << ndim
          << " dimensions.";
      throw std::invalid_argument(msg.str());
    }
    // Accumulate so that a repeated axis stays consistent with `out_shape`,
    // which the caller grows by low + high for every entry of `axes`.
    lows[ax] += low_pad_size[i];
    highs[ax] += high_pad_size[i];
  }

  // Copy over values from the unpadded array into the interior of a
  // zero-filled output.
  auto interior_stops = a.shape();
  for (int d = 0; d < ndim; ++d) {
    interior_stops[d] += lows[d];
  }
  array padded = slice_update(
      zeros(out_shape, a.dtype(), s), a, lows, interior_stops, s);

  // Fill the padding one axis at a time. Each pass slices across the full
  // extent of every other axis, so regions filled by earlier passes are
  // extended too and the corners end up with the corner values.
  for (int axis = 0; axis < ndim; ++axis) {
    if (lows[axis] > 0) {
      std::vector<int> starts(ndim, 0);
      std::vector<int> stops(out_shape);

      // Fetch the first slice of real data along this axis
      starts[axis] = lows[axis];
      stops[axis] = lows[axis] + 1;
      array edge_value = slice(padded, starts, stops, s);

      // Broadcast it over the leading padding region
      starts[axis] = 0;
      stops[axis] = lows[axis];
      padded = slice_update(padded, edge_value, starts, stops, s);
    }

    if (highs[axis] > 0) {
      std::vector<int> starts(ndim, 0);
      std::vector<int> stops(out_shape);

      // One past the last slice of real data along this axis
      const int data_end = out_shape[axis] - highs[axis];

      // Fetch the last slice of real data along this axis
      starts[axis] = data_end - 1;
      stops[axis] = data_end;
      array edge_value = slice(padded, starts, stops, s);

      // Broadcast it over the trailing padding region
      starts[axis] = data_end;
      stops[axis] = out_shape[axis];
      padded = slice_update(padded, edge_value, starts, stops, s);
    }
  }
  return padded;
}
|
||||||
|
|
||||||
/** Pad an array with a constant value */
|
/** Pad an array with a constant value */
|
||||||
array pad(
|
array pad(
|
||||||
const array& a,
|
const array& a,
|
||||||
@ -1007,6 +1052,7 @@ array pad(
|
|||||||
const std::vector<int>& low_pad_size,
|
const std::vector<int>& low_pad_size,
|
||||||
const std::vector<int>& high_pad_size,
|
const std::vector<int>& high_pad_size,
|
||||||
const array& pad_value /*= array(0)*/,
|
const array& pad_value /*= array(0)*/,
|
||||||
|
const std::string mode /*= "constant"*/,
|
||||||
StreamOrDevice s /* = {}*/) {
|
StreamOrDevice s /* = {}*/) {
|
||||||
if (axes.size() != low_pad_size.size() ||
|
if (axes.size() != low_pad_size.size() ||
|
||||||
axes.size() != high_pad_size.size()) {
|
axes.size() != high_pad_size.size()) {
|
||||||
@ -1038,11 +1084,19 @@ array pad(
|
|||||||
out_shape[ax] += low_pad_size[i] + high_pad_size[i];
|
out_shape[ax] += low_pad_size[i] + high_pad_size[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
return array(
|
if (mode == "constant") {
|
||||||
out_shape,
|
return array(
|
||||||
a.dtype(),
|
out_shape,
|
||||||
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
|
a.dtype(),
|
||||||
{a, astype(pad_value, a.dtype(), s)});
|
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
|
||||||
|
{a, astype(pad_value, a.dtype(), s)});
|
||||||
|
} else if (mode == "edge") {
|
||||||
|
return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
|
||||||
|
} else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Invalid padding mode (" << mode << ") passed to pad";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Pad an array with a constant value along all axes */
|
/** Pad an array with a constant value along all axes */
|
||||||
@ -1050,6 +1104,7 @@ array pad(
|
|||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<std::pair<int, int>>& pad_width,
|
const std::vector<std::pair<int, int>>& pad_width,
|
||||||
const array& pad_value /*= array(0)*/,
|
const array& pad_value /*= array(0)*/,
|
||||||
|
const std::string mode /*= "constant"*/,
|
||||||
StreamOrDevice s /*= {}*/) {
|
StreamOrDevice s /*= {}*/) {
|
||||||
std::vector<int> axes(a.ndim(), 0);
|
std::vector<int> axes(a.ndim(), 0);
|
||||||
std::iota(axes.begin(), axes.end(), 0);
|
std::iota(axes.begin(), axes.end(), 0);
|
||||||
@ -1062,27 +1117,34 @@ array pad(
|
|||||||
highs.push_back(pads.second);
|
highs.push_back(pads.second);
|
||||||
}
|
}
|
||||||
|
|
||||||
return pad(a, axes, lows, highs, pad_value, s);
|
return pad(a, axes, lows, highs, pad_value, mode, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array pad(
|
array pad(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::pair<int, int>& pad_width,
|
const std::pair<int, int>& pad_width,
|
||||||
const array& pad_value /*= array(0)*/,
|
const array& pad_value /*= array(0)*/,
|
||||||
|
const std::string mode /*= "constant"*/,
|
||||||
StreamOrDevice s /*= {}*/) {
|
StreamOrDevice s /*= {}*/) {
|
||||||
return pad(
|
return pad(
|
||||||
a, std::vector<std::pair<int, int>>(a.ndim(), pad_width), pad_value, s);
|
a,
|
||||||
|
std::vector<std::pair<int, int>>(a.ndim(), pad_width),
|
||||||
|
pad_value,
|
||||||
|
mode,
|
||||||
|
s);
|
||||||
}
|
}
|
||||||
|
|
||||||
array pad(
|
array pad(
|
||||||
const array& a,
|
const array& a,
|
||||||
int pad_width,
|
int pad_width,
|
||||||
const array& pad_value /*= array(0)*/,
|
const array& pad_value /*= array(0)*/,
|
||||||
|
const std::string mode /*= "constant"*/,
|
||||||
StreamOrDevice s /*= {}*/) {
|
StreamOrDevice s /*= {}*/) {
|
||||||
return pad(
|
return pad(
|
||||||
a,
|
a,
|
||||||
std::vector<std::pair<int, int>>(a.ndim(), {pad_width, pad_width}),
|
std::vector<std::pair<int, int>>(a.ndim(), {pad_width, pad_width}),
|
||||||
pad_value,
|
pad_value,
|
||||||
|
mode,
|
||||||
s);
|
s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,6 +263,7 @@ array pad(
|
|||||||
const std::vector<int>& low_pad_size,
|
const std::vector<int>& low_pad_size,
|
||||||
const std::vector<int>& high_pad_size,
|
const std::vector<int>& high_pad_size,
|
||||||
const array& pad_value = array(0),
|
const array& pad_value = array(0),
|
||||||
|
const std::string mode = "constant",
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Pad an array with a constant value along all axes */
|
/** Pad an array with a constant value along all axes */
|
||||||
@ -270,16 +271,19 @@ array pad(
|
|||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<std::pair<int, int>>& pad_width,
|
const std::vector<std::pair<int, int>>& pad_width,
|
||||||
const array& pad_value = array(0),
|
const array& pad_value = array(0),
|
||||||
|
const std::string mode = "constant",
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
array pad(
|
array pad(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::pair<int, int>& pad_width,
|
const std::pair<int, int>& pad_width,
|
||||||
const array& pad_value = array(0),
|
const array& pad_value = array(0),
|
||||||
|
const std::string mode = "constant",
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
array pad(
|
array pad(
|
||||||
const array& a,
|
const array& a,
|
||||||
int pad_width,
|
int pad_width,
|
||||||
const array& pad_value = array(0),
|
const array& pad_value = array(0),
|
||||||
|
const std::string mode = "constant",
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Permutes the dimensions in reverse order. */
|
/** Permutes the dimensions in reverse order. */
|
||||||
|
@ -855,8 +855,8 @@ array conv_weight_backward_patches(
|
|||||||
// Pad input
|
// Pad input
|
||||||
std::vector<int> padded_axes(in.ndim() - 2, 0);
|
std::vector<int> padded_axes(in.ndim() - 2, 0);
|
||||||
std::iota(padded_axes.begin(), padded_axes.end(), 1);
|
std::iota(padded_axes.begin(), padded_axes.end(), 1);
|
||||||
auto in_padded =
|
auto in_padded = pad(
|
||||||
pad(in, padded_axes, padding, padding, array(0, in.dtype()), s);
|
in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
|
||||||
|
|
||||||
// Resolve strided patches
|
// Resolve strided patches
|
||||||
|
|
||||||
@ -2289,6 +2289,7 @@ std::vector<array> Pad::jvp(
|
|||||||
low_pad_size_,
|
low_pad_size_,
|
||||||
high_pad_size_,
|
high_pad_size_,
|
||||||
array(0, tangents[0].dtype()),
|
array(0, tangents[0].dtype()),
|
||||||
|
"constant",
|
||||||
stream())};
|
stream())};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3740,7 +3741,8 @@ std::vector<array> BlockMaskedMM::vjp(
|
|||||||
|
|
||||||
// Pad if needed
|
// Pad if needed
|
||||||
if ((align_Y != 0) || (align_X != 0)) {
|
if ((align_Y != 0) || (align_X != 0)) {
|
||||||
r = pad(r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, stream());
|
r = pad(
|
||||||
|
r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, "constant", stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reshape
|
// Reshape
|
||||||
|
@ -101,7 +101,11 @@ class _Pool(Module):
|
|||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
if any(p[0] > 0 for p in self._padding):
|
if any(p[0] > 0 for p in self._padding):
|
||||||
x = mx.pad(x, [(0, 0)] + self._padding + [(0, 0)], self._padding_value)
|
x = mx.pad(
|
||||||
|
x,
|
||||||
|
[(0, 0)] + self._padding + [(0, 0)],
|
||||||
|
constant_values=self._padding_value,
|
||||||
|
)
|
||||||
x = _sliding_windows(x, self._kernel_size, self._stride)
|
x = _sliding_windows(x, self._kernel_size, self._stride)
|
||||||
return self._pooling_function(x, self._axes)
|
return self._pooling_function(x, self._axes)
|
||||||
|
|
||||||
|
@ -2843,30 +2843,32 @@ void init_ops(nb::module_& m) {
|
|||||||
std::tuple<int>,
|
std::tuple<int>,
|
||||||
std::pair<int, int>,
|
std::pair<int, int>,
|
||||||
std::vector<std::pair<int, int>>>& pad_width,
|
std::vector<std::pair<int, int>>>& pad_width,
|
||||||
|
const std::string mode,
|
||||||
const ScalarOrArray& constant_value,
|
const ScalarOrArray& constant_value,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (auto pv = std::get_if<int>(&pad_width); pv) {
|
if (auto pv = std::get_if<int>(&pad_width); pv) {
|
||||||
return pad(a, *pv, to_array(constant_value), s);
|
return pad(a, *pv, to_array(constant_value), mode, s);
|
||||||
} else if (auto pv = std::get_if<std::tuple<int>>(&pad_width); pv) {
|
} else if (auto pv = std::get_if<std::tuple<int>>(&pad_width); pv) {
|
||||||
return pad(a, std::get<0>(*pv), to_array(constant_value), s);
|
return pad(a, std::get<0>(*pv), to_array(constant_value), mode, s);
|
||||||
} else if (auto pv = std::get_if<std::pair<int, int>>(&pad_width); pv) {
|
} else if (auto pv = std::get_if<std::pair<int, int>>(&pad_width); pv) {
|
||||||
return pad(a, *pv, to_array(constant_value), s);
|
return pad(a, *pv, to_array(constant_value), mode, s);
|
||||||
} else {
|
} else {
|
||||||
auto v = std::get<std::vector<std::pair<int, int>>>(pad_width);
|
auto v = std::get<std::vector<std::pair<int, int>>>(pad_width);
|
||||||
if (v.size() == 1) {
|
if (v.size() == 1) {
|
||||||
return pad(a, v[0], to_array(constant_value), s);
|
return pad(a, v[0], to_array(constant_value), mode, s);
|
||||||
} else {
|
} else {
|
||||||
return pad(a, v, to_array(constant_value), s);
|
return pad(a, v, to_array(constant_value), mode, s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
"pad_width"_a,
|
"pad_width"_a,
|
||||||
|
"mode"_a = "constant",
|
||||||
"constant_values"_a = 0,
|
"constant_values"_a = 0,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
"def pad(a: array, pad_width: Union[int, Tuple[int], Tuple[int, int], List[Tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Pad an array with a constant value
|
Pad an array with a constant value
|
||||||
|
|
||||||
@ -2878,6 +2880,9 @@ void init_ops(nb::module_& m) {
|
|||||||
of integers is passed then ``(before_i, after_i)`` are all the same.
|
of integers is passed then ``(before_i, after_i)`` are all the same.
|
||||||
If a single integer or tuple with a single integer is passed then
|
If a single integer or tuple with a single integer is passed then
|
||||||
all axes are extended by the same number on each side.
|
all axes are extended by the same number on each side.
|
||||||
|
mode: Padding mode. One of the following strings:
|
||||||
|
"constant" (default): Pads with a constant value.
|
||||||
|
"edge": Pads with the edge values of array.
|
||||||
constant_value (array or scalar, optional): Optional constant value
|
constant_value (array or scalar, optional): Optional constant value
|
||||||
to pad the edges of the array with.
|
to pad the edges of the array with.
|
||||||
|
|
||||||
@ -3155,7 +3160,8 @@ void init_ops(nb::module_& m) {
|
|||||||
} else { // Even sizes use asymmetric padding
|
} else { // Even sizes use asymmetric padding
|
||||||
int pad_l = wt.size() / 2;
|
int pad_l = wt.size() / 2;
|
||||||
int pad_r = std::max(0, pad_l - 1);
|
int pad_r = std::max(0, pad_l - 1);
|
||||||
in = pad(in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), s);
|
in = pad(
|
||||||
|
in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), "constant", s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
@ -1623,6 +1623,12 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
|
self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
|
||||||
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
|
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
|
||||||
|
|
||||||
|
b_npy = np.pad(a_npy, pw, mode="edge")
|
||||||
|
b_mlx = mx.pad(a_mlx, pw, mode="edge")
|
||||||
|
|
||||||
|
self.assertEqual(list(b_npy.shape), list(b_mlx.shape))
|
||||||
|
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
|
||||||
|
|
||||||
a = mx.zeros((1, 1, 1))
|
a = mx.zeros((1, 1, 1))
|
||||||
self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
|
self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
|
||||||
self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))
|
self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))
|
||||||
|
Loading…
Reference in New Issue
Block a user