From 240d10699c19699ddd7808bcb8032ab01c89f0e8 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Tue, 26 Mar 2024 14:59:19 -0700 Subject: [PATCH] Implement negative padding in conv with slicing (#907) * Implement negative padding with slicing * Update mlx/ops.cpp Co-authored-by: Awni Hannun --------- Co-authored-by: Awni Hannun --- mlx/ops.cpp | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ca3d60dd7..2dfe33a05 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2971,6 +2971,35 @@ array conv_general( input_dilation = std::vector(spatial_dims, input_dilation_int); } + // Check for negative padding + bool has_neg_padding = false; + for (auto& pd : padding_lo) { + has_neg_padding = (pd < 0); + } + for (auto& pd : padding_hi) { + has_neg_padding = (pd < 0); + } + + // Handle negative padding + if (has_neg_padding) { + std::vector starts(in.ndim(), 0); + std::vector stops = in.shape(); + + for (int i = 0; i < spatial_dims; i++) { + if (padding_lo[i] < 0) { + starts[i + 1] -= padding_lo[i]; + padding_lo[i] = 0; + } + + if (padding_hi[i] < 0) { + stops[i + 1] += padding_hi[i]; + padding_hi[i] = 0; + } + } + + in = slice(in, std::move(starts), std::move(stops), s); + } + // Get output shapes std::vector out_shape = conv_out_shape( in.shape(),