Added output_padding parameters in conv_transpose (#2092)

This commit is contained in:
Param Thakkar
2025-04-23 21:56:33 +05:30
committed by GitHub
parent 3836445241
commit 600e87e03c
6 changed files with 366 additions and 14 deletions

View File

@@ -3769,6 +3769,7 @@ array conv_transpose_general(
std::vector<int> stride,
std::vector<int> padding,
std::vector<int> dilation,
std::vector<int> output_padding,
int groups,
StreamOrDevice s) {
std::vector<int> padding_lo(padding.size());
@@ -3782,7 +3783,8 @@ array conv_transpose_general(
int in_size = 1 + (conv_output_shape - 1);
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
padding_hi[i] = in_size - out_size + padding[i];
padding_hi[i] = in_size - out_size + padding[i] +
output_padding[i]; // Adjust with output_padding
}
return conv_general(
@@ -3805,10 +3807,11 @@ array conv_transpose1d(
int stride /* = 1 */,
int padding /* = 0 */,
int dilation /* = 1 */,
int output_padding /* = 0 */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_transpose_general(
in_, wt_, {stride}, {padding}, {dilation}, groups, s);
in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s);
}
/** 2D transposed convolution with a filter */
@@ -3818,6 +3821,7 @@ array conv_transpose2d(
const std::pair<int, int>& stride /* = {1, 1} */,
const std::pair<int, int>& padding /* = {0, 0} */,
const std::pair<int, int>& dilation /* = {1, 1} */,
const std::pair<int, int>& output_padding /* = {0, 0} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_transpose_general(
@@ -3826,6 +3830,7 @@ array conv_transpose2d(
{stride.first, stride.second},
{padding.first, padding.second},
{dilation.first, dilation.second},
{output_padding.first, output_padding.second},
groups,
s);
}
@@ -3837,6 +3842,7 @@ array conv_transpose3d(
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
const std::tuple<int, int, int>& output_padding /* = {0, 0, 0} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_transpose_general(
@@ -3845,6 +3851,9 @@ array conv_transpose3d(
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
{std::get<0>(output_padding),
std::get<1>(output_padding),
std::get<2>(output_padding)},
groups,
s);
}

View File

@@ -1291,6 +1291,7 @@ array conv_transpose1d(
int stride = 1,
int padding = 0,
int dilation = 1,
int output_padding = 0,
int groups = 1,
StreamOrDevice s = {});
@@ -1301,6 +1302,7 @@ array conv_transpose2d(
const std::pair<int, int>& stride = {1, 1},
const std::pair<int, int>& padding = {0, 0},
const std::pair<int, int>& dilation = {1, 1},
const std::pair<int, int>& output_padding = {0, 0},
int groups = 1,
StreamOrDevice s = {});
@@ -1311,6 +1313,7 @@ array conv_transpose3d(
const std::tuple<int, int, int>& stride = {1, 1, 1},
const std::tuple<int, int, int>& padding = {0, 0, 0},
const std::tuple<int, int, int>& dilation = {1, 1, 1},
const std::tuple<int, int, int>& output_padding = {0, 0, 0},
int groups = 1,
StreamOrDevice s = {});