diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ca3d60dd7..2dfe33a05 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2971,6 +2971,35 @@ array conv_general( input_dilation = std::vector(spatial_dims, input_dilation_int); } + // Check for negative padding + bool has_neg_padding = false; + for (auto& pd : padding_lo) { + has_neg_padding = (pd < 0); + } + for (auto& pd : padding_hi) { + has_neg_padding = (pd < 0); + } + + // Handle negative padding + if (has_neg_padding) { + std::vector starts(in.ndim(), 0); + std::vector stops = in.shape(); + + for (int i = 0; i < spatial_dims; i++) { + if (padding_lo[i] < 0) { + starts[i + 1] -= padding_lo[i]; + padding_lo[i] = 0; + } + + if (padding_hi[i] < 0) { + stops[i + 1] += padding_hi[i]; + padding_hi[i] = 0; + } + } + + in = slice(in, std::move(starts), std::move(stops), s); + } + // Get output shapes std::vector out_shape = conv_out_shape( in.shape(),