Implement negative padding in conv with slicing (#907)

* Implement negative padding with slicing

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni@apple.com>

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Jagrit Digani 2024-03-26 14:59:19 -07:00 committed by GitHub
parent 925014b661
commit 240d10699c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2971,6 +2971,35 @@ array conv_general(
input_dilation = std::vector<int>(spatial_dims, input_dilation_int);
}
// Check for negative padding
bool has_neg_padding = false;
for (auto& pd : padding_lo) {
has_neg_padding = (pd < 0);
}
for (auto& pd : padding_hi) {
has_neg_padding = (pd < 0);
}
// Handle negative padding
if (has_neg_padding) {
std::vector<int> starts(in.ndim(), 0);
std::vector<int> stops = in.shape();
for (int i = 0; i < spatial_dims; i++) {
if (padding_lo[i] < 0) {
starts[i + 1] -= padding_lo[i];
padding_lo[i] = 0;
}
if (padding_hi[i] < 0) {
stops[i + 1] += padding_hi[i];
padding_hi[i] = 0;
}
}
in = slice(in, std::move(starts), std::move(stops), s);
}
// Get output shapes
std::vector<int> out_shape = conv_out_shape(
in.shape(),