mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Added output_padding parameters in conv_transpose (#2092)
This commit is contained in:
13
mlx/ops.cpp
13
mlx/ops.cpp
@@ -3769,6 +3769,7 @@ array conv_transpose_general(
|
||||
std::vector<int> stride,
|
||||
std::vector<int> padding,
|
||||
std::vector<int> dilation,
|
||||
std::vector<int> output_padding,
|
||||
int groups,
|
||||
StreamOrDevice s) {
|
||||
std::vector<int> padding_lo(padding.size());
|
||||
@@ -3782,7 +3783,8 @@ array conv_transpose_general(
|
||||
|
||||
int in_size = 1 + (conv_output_shape - 1);
|
||||
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
|
||||
padding_hi[i] = in_size - out_size + padding[i];
|
||||
padding_hi[i] = in_size - out_size + padding[i] +
|
||||
output_padding[i]; // Adjust with output_padding
|
||||
}
|
||||
|
||||
return conv_general(
|
||||
@@ -3805,10 +3807,11 @@ array conv_transpose1d(
|
||||
int stride /* = 1 */,
|
||||
int padding /* = 0 */,
|
||||
int dilation /* = 1 */,
|
||||
int output_padding /* = 0 */,
|
||||
int groups /* = 1 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return conv_transpose_general(
|
||||
in_, wt_, {stride}, {padding}, {dilation}, groups, s);
|
||||
in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s);
|
||||
}
|
||||
|
||||
/** 2D transposed convolution with a filter */
|
||||
@@ -3818,6 +3821,7 @@ array conv_transpose2d(
|
||||
const std::pair<int, int>& stride /* = {1, 1} */,
|
||||
const std::pair<int, int>& padding /* = {0, 0} */,
|
||||
const std::pair<int, int>& dilation /* = {1, 1} */,
|
||||
const std::pair<int, int>& output_padding /* = {0, 0} */,
|
||||
int groups /* = 1 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return conv_transpose_general(
|
||||
@@ -3826,6 +3830,7 @@ array conv_transpose2d(
|
||||
{stride.first, stride.second},
|
||||
{padding.first, padding.second},
|
||||
{dilation.first, dilation.second},
|
||||
{output_padding.first, output_padding.second},
|
||||
groups,
|
||||
s);
|
||||
}
|
||||
@@ -3837,6 +3842,7 @@ array conv_transpose3d(
|
||||
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
|
||||
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
|
||||
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
|
||||
const std::tuple<int, int, int>& output_padding /* = {0, 0, 0} */,
|
||||
int groups /* = 1 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
return conv_transpose_general(
|
||||
@@ -3845,6 +3851,9 @@ array conv_transpose3d(
|
||||
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
|
||||
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
|
||||
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
|
||||
{std::get<0>(output_padding),
|
||||
std::get<1>(output_padding),
|
||||
std::get<2>(output_padding)},
|
||||
groups,
|
||||
s);
|
||||
}
|
||||
|
||||
@@ -1291,6 +1291,7 @@ array conv_transpose1d(
|
||||
int stride = 1,
|
||||
int padding = 0,
|
||||
int dilation = 1,
|
||||
int output_padding = 0,
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
@@ -1301,6 +1302,7 @@ array conv_transpose2d(
|
||||
const std::pair<int, int>& stride = {1, 1},
|
||||
const std::pair<int, int>& padding = {0, 0},
|
||||
const std::pair<int, int>& dilation = {1, 1},
|
||||
const std::pair<int, int>& output_padding = {0, 0},
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
@@ -1311,6 +1313,7 @@ array conv_transpose3d(
|
||||
const std::tuple<int, int, int>& stride = {1, 1, 1},
|
||||
const std::tuple<int, int, int>& padding = {0, 0, 0},
|
||||
const std::tuple<int, int, int>& dilation = {1, 1, 1},
|
||||
const std::tuple<int, int, int>& output_padding = {0, 0, 0},
|
||||
int groups = 1,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user