mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-29 13:01:14 +08:00
Implement negative padding in conv with slicing (#907)
* Implement negative padding with slicing * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni@apple.com> --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
925014b661
commit
240d10699c
29
mlx/ops.cpp
29
mlx/ops.cpp
@ -2971,6 +2971,35 @@ array conv_general(
|
||||
input_dilation = std::vector<int>(spatial_dims, input_dilation_int);
|
||||
}
|
||||
|
||||
// Check for negative padding
|
||||
bool has_neg_padding = false;
|
||||
for (auto& pd : padding_lo) {
|
||||
has_neg_padding = (pd < 0);
|
||||
}
|
||||
for (auto& pd : padding_hi) {
|
||||
has_neg_padding = (pd < 0);
|
||||
}
|
||||
|
||||
// Handle negative padding
|
||||
if (has_neg_padding) {
|
||||
std::vector<int> starts(in.ndim(), 0);
|
||||
std::vector<int> stops = in.shape();
|
||||
|
||||
for (int i = 0; i < spatial_dims; i++) {
|
||||
if (padding_lo[i] < 0) {
|
||||
starts[i + 1] -= padding_lo[i];
|
||||
padding_lo[i] = 0;
|
||||
}
|
||||
|
||||
if (padding_hi[i] < 0) {
|
||||
stops[i + 1] += padding_hi[i];
|
||||
padding_hi[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
in = slice(in, std::move(starts), std::move(stops), s);
|
||||
}
|
||||
|
||||
// Get output shapes
|
||||
std::vector<int> out_shape = conv_out_shape(
|
||||
in.shape(),
|
||||
|
Loading…
Reference in New Issue
Block a user