mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-18 00:16:39 +08:00
Implement negative padding in conv with slicing (#907)
* Implement negative padding with slicing * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni@apple.com> --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
925014b661
commit
240d10699c
29
mlx/ops.cpp
29
mlx/ops.cpp
@ -2971,6 +2971,35 @@ array conv_general(
|
|||||||
input_dilation = std::vector<int>(spatial_dims, input_dilation_int);
|
input_dilation = std::vector<int>(spatial_dims, input_dilation_int);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for negative padding
|
||||||
|
bool has_neg_padding = false;
|
||||||
|
for (auto& pd : padding_lo) {
|
||||||
|
has_neg_padding = (pd < 0);
|
||||||
|
}
|
||||||
|
for (auto& pd : padding_hi) {
|
||||||
|
has_neg_padding = (pd < 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle negative padding
|
||||||
|
if (has_neg_padding) {
|
||||||
|
std::vector<int> starts(in.ndim(), 0);
|
||||||
|
std::vector<int> stops = in.shape();
|
||||||
|
|
||||||
|
for (int i = 0; i < spatial_dims; i++) {
|
||||||
|
if (padding_lo[i] < 0) {
|
||||||
|
starts[i + 1] -= padding_lo[i];
|
||||||
|
padding_lo[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (padding_hi[i] < 0) {
|
||||||
|
stops[i + 1] += padding_hi[i];
|
||||||
|
padding_hi[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
in = slice(in, std::move(starts), std::move(stops), s);
|
||||||
|
}
|
||||||
|
|
||||||
// Get output shapes
|
// Get output shapes
|
||||||
std::vector<int> out_shape = conv_out_shape(
|
std::vector<int> out_shape = conv_out_shape(
|
||||||
in.shape(),
|
in.shape(),
|
||||||
|
Loading…
Reference in New Issue
Block a user