mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 02:38:09 +08:00
Add "edge" mode to mx.pad (#1309)
* Add edge padding mode * fix pad in pooling * string arg instead of enum
This commit is contained in:
76
mlx/ops.cpp
76
mlx/ops.cpp
@@ -1000,6 +1000,51 @@ array tile(
|
||||
return reshape(x, final_shape, s);
|
||||
}
|
||||
|
||||
array edge_pad(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& low_pad_size,
|
||||
const std::vector<int>& high_pad_size,
|
||||
const std::vector<int>& out_shape,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
array out = zeros(out_shape, a.dtype(), s);
|
||||
auto stops = a.shape();
|
||||
for (int i = 0; i < stops.size(); i++) {
|
||||
stops[i] += low_pad_size[i];
|
||||
}
|
||||
// Copy over values from the unpadded array
|
||||
array padded = slice_update(out, a, low_pad_size, stops, s);
|
||||
|
||||
for (int axis = 0; axis < a.ndim(); axis++) {
|
||||
if (low_pad_size[axis] > 0) {
|
||||
std::vector<int> starts(a.ndim(), 0);
|
||||
starts[axis] = low_pad_size[axis];
|
||||
auto stops = out.shape();
|
||||
stops[axis] = low_pad_size[axis] + 1;
|
||||
// Fetch edge values
|
||||
array edge_value = slice(padded, starts, stops, s);
|
||||
|
||||
starts[axis] = 0;
|
||||
stops[axis] = low_pad_size[axis];
|
||||
// Update edge values in the padded array
|
||||
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||
}
|
||||
|
||||
if (high_pad_size[axis] > 0) {
|
||||
std::vector<int> starts(a.ndim(), 0);
|
||||
starts[axis] = -high_pad_size[axis] - 1;
|
||||
auto stops = out.shape();
|
||||
stops[axis] = -high_pad_size[axis];
|
||||
array edge_value = slice(padded, starts, stops, s);
|
||||
|
||||
starts[axis] = -high_pad_size[axis];
|
||||
stops[axis] = out.shape(axis);
|
||||
padded = slice_update(padded, edge_value, starts, stops, s);
|
||||
}
|
||||
}
|
||||
return padded;
|
||||
}
|
||||
|
||||
/** Pad an array with a constant value */
|
||||
array pad(
|
||||
const array& a,
|
||||
@@ -1007,6 +1052,7 @@ array pad(
|
||||
const std::vector<int>& low_pad_size,
|
||||
const std::vector<int>& high_pad_size,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
StreamOrDevice s /* = {}*/) {
|
||||
if (axes.size() != low_pad_size.size() ||
|
||||
axes.size() != high_pad_size.size()) {
|
||||
@@ -1038,11 +1084,19 @@ array pad(
|
||||
out_shape[ax] += low_pad_size[i] + high_pad_size[i];
|
||||
}
|
||||
|
||||
return array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
|
||||
{a, astype(pad_value, a.dtype(), s)});
|
||||
if (mode == "constant") {
|
||||
return array(
|
||||
out_shape,
|
||||
a.dtype(),
|
||||
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
|
||||
{a, astype(pad_value, a.dtype(), s)});
|
||||
} else if (mode == "edge") {
|
||||
return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s);
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "Invalid padding mode (" << mode << ") passed to pad";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
/** Pad an array with a constant value along all axes */
|
||||
@@ -1050,6 +1104,7 @@ array pad(
|
||||
const array& a,
|
||||
const std::vector<std::pair<int, int>>& pad_width,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
std::vector<int> axes(a.ndim(), 0);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
@@ -1062,27 +1117,34 @@ array pad(
|
||||
highs.push_back(pads.second);
|
||||
}
|
||||
|
||||
return pad(a, axes, lows, highs, pad_value, s);
|
||||
return pad(a, axes, lows, highs, pad_value, mode, s);
|
||||
}
|
||||
|
||||
array pad(
|
||||
const array& a,
|
||||
const std::pair<int, int>& pad_width,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
return pad(
|
||||
a, std::vector<std::pair<int, int>>(a.ndim(), pad_width), pad_value, s);
|
||||
a,
|
||||
std::vector<std::pair<int, int>>(a.ndim(), pad_width),
|
||||
pad_value,
|
||||
mode,
|
||||
s);
|
||||
}
|
||||
|
||||
array pad(
|
||||
const array& a,
|
||||
int pad_width,
|
||||
const array& pad_value /*= array(0)*/,
|
||||
const std::string mode /*= "constant"*/,
|
||||
StreamOrDevice s /*= {}*/) {
|
||||
return pad(
|
||||
a,
|
||||
std::vector<std::pair<int, int>>(a.ndim(), {pad_width, pad_width}),
|
||||
pad_value,
|
||||
mode,
|
||||
s);
|
||||
}
|
||||
|
||||
|
@@ -263,6 +263,7 @@ array pad(
|
||||
const std::vector<int>& low_pad_size,
|
||||
const std::vector<int>& high_pad_size,
|
||||
const array& pad_value = array(0),
|
||||
const std::string mode = "constant",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Pad an array with a constant value along all axes */
|
||||
@@ -270,16 +271,19 @@ array pad(
|
||||
const array& a,
|
||||
const std::vector<std::pair<int, int>>& pad_width,
|
||||
const array& pad_value = array(0),
|
||||
const std::string mode = "constant",
|
||||
StreamOrDevice s = {});
|
||||
array pad(
|
||||
const array& a,
|
||||
const std::pair<int, int>& pad_width,
|
||||
const array& pad_value = array(0),
|
||||
const std::string mode = "constant",
|
||||
StreamOrDevice s = {});
|
||||
array pad(
|
||||
const array& a,
|
||||
int pad_width,
|
||||
const array& pad_value = array(0),
|
||||
const std::string mode = "constant",
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Permutes the dimensions in reverse order. */
|
||||
|
@@ -855,8 +855,8 @@ array conv_weight_backward_patches(
|
||||
// Pad input
|
||||
std::vector<int> padded_axes(in.ndim() - 2, 0);
|
||||
std::iota(padded_axes.begin(), padded_axes.end(), 1);
|
||||
auto in_padded =
|
||||
pad(in, padded_axes, padding, padding, array(0, in.dtype()), s);
|
||||
auto in_padded = pad(
|
||||
in, padded_axes, padding, padding, array(0, in.dtype()), "constant", s);
|
||||
|
||||
// Resolve strided patches
|
||||
|
||||
@@ -2289,6 +2289,7 @@ std::vector<array> Pad::jvp(
|
||||
low_pad_size_,
|
||||
high_pad_size_,
|
||||
array(0, tangents[0].dtype()),
|
||||
"constant",
|
||||
stream())};
|
||||
}
|
||||
|
||||
@@ -3740,7 +3741,8 @@ std::vector<array> BlockMaskedMM::vjp(
|
||||
|
||||
// Pad if needed
|
||||
if ((align_Y != 0) || (align_X != 0)) {
|
||||
r = pad(r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, stream());
|
||||
r = pad(
|
||||
r, {-2, -1}, {0, 0}, {align_Y, align_X}, zero, "constant", stream());
|
||||
}
|
||||
|
||||
// Reshape
|
||||
|
Reference in New Issue
Block a user