diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 54ac62fef3..c2aa4786f5 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3769,6 +3769,7 @@ array conv_transpose_general( std::vector<int> stride, std::vector<int> padding, std::vector<int> dilation, + std::vector<int> output_padding, int groups, StreamOrDevice s) { std::vector<int> padding_lo(padding.size()); @@ -3782,7 +3783,8 @@ array conv_transpose_general( int in_size = 1 + (conv_output_shape - 1); int out_size = 1 + stride[i] * (input.shape(1 + i) - 1); - padding_hi[i] = in_size - out_size + padding[i]; + padding_hi[i] = in_size - out_size + padding[i] + + output_padding[i]; // Adjust with output_padding } return conv_general( @@ -3805,10 +3807,11 @@ array conv_transpose1d( int stride /* = 1 */, int padding /* = 0 */, int dilation /* = 1 */, + int output_padding /* = 0 */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( - in_, wt_, {stride}, {padding}, {dilation}, groups, s); + in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s); } /** 2D transposed convolution with a filter */ @@ -3818,6 +3821,7 @@ array conv_transpose2d( const std::pair<int, int>& stride /* = {1, 1} */, const std::pair<int, int>& padding /* = {0, 0} */, const std::pair<int, int>& dilation /* = {1, 1} */, + const std::pair<int, int>& output_padding /* = {0, 0} */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( @@ -3826,6 +3830,7 @@ array conv_transpose2d( {stride.first, stride.second}, {padding.first, padding.second}, {dilation.first, dilation.second}, + {output_padding.first, output_padding.second}, groups, s); } @@ -3837,6 +3842,7 @@ array conv_transpose3d( const std::tuple<int, int, int>& stride /* = {1, 1, 1} */, const std::tuple<int, int, int>& padding /* = {0, 0, 0} */, const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */, + const std::tuple<int, int, int>& output_padding /* = {0, 0, 0} */, int groups /* = 1 */, StreamOrDevice s /* = {} */) { return conv_transpose_general( @@ -3845,6 +3851,9 @@ array conv_transpose3d( {std::get<0>(stride), std::get<1>(stride), 
std::get<2>(stride)}, {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)}, {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)}, + {std::get<0>(output_padding), + std::get<1>(output_padding), + std::get<2>(output_padding)}, groups, s); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index f98aa80aac..8c5cea9f95 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3609,6 +3609,7 @@ void init_ops(nb::module_& m) { "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), @@ -3623,6 +3624,7 @@ void init_ops(nb::module_& m) { stride (int, optional): Kernel stride. Default: ``1``. padding (int, optional): Input padding. Default: ``0``. dilation (int, optional): Kernel dilation. Default: ``1``. + output_padding (int, optional): Output padding. Default: ``0``. groups (int, optional): Input feature groups. Default: ``1``. Returns: @@ -3635,11 +3637,13 @@ void init_ops(nb::module_& m) { const std::variant<int, std::pair<int, int>>& stride, const std::variant<int, std::pair<int, int>>& padding, const std::variant<int, std::pair<int, int>>& dilation, + const std::variant<int, std::pair<int, int>>& output_padding, int groups, mx::StreamOrDevice s) { std::pair<int, int> stride_pair{1, 1}; std::pair<int, int> padding_pair{0, 0}; std::pair<int, int> dilation_pair{1, 1}; + std::pair<int, int> output_padding_pair{0, 0}; if (auto pv = std::get_if<int>(&stride); pv) { stride_pair = std::pair<int, int>{*pv, *pv}; @@ -3659,14 +3663,28 @@ void init_ops(nb::module_& m) { dilation_pair = std::get<std::pair<int, int>>(dilation); } + if (auto pv = std::get_if<int>(&output_padding); pv) { + output_padding_pair = std::pair<int, int>{*pv, *pv}; + } else { + output_padding_pair = std::get<std::pair<int, int>>(output_padding); + } + return mx::conv_transpose2d( - input, weight, stride_pair, padding_pair, dilation_pair, groups, s); + input, + weight, + stride_pair, + padding_pair, + dilation_pair, + output_padding_pair, + groups, + s); }, nb::arg(), nb::arg(), "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = 
nb::none(), @@ -3689,6 +3707,9 @@ void init_ops(nb::module_& m) { dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` + output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with + output padding. All spatial dimensions get the same output + padding if only one number is specified. Default: ``0``. groups (int, optional): input feature groups. Default: ``1``. Returns: @@ -3701,11 +3722,13 @@ void init_ops(nb::module_& m) { const std::variant<int, std::tuple<int, int, int>>& stride, const std::variant<int, std::tuple<int, int, int>>& padding, const std::variant<int, std::tuple<int, int, int>>& dilation, + const std::variant<int, std::tuple<int, int, int>>& output_padding, int groups, mx::StreamOrDevice s) { std::tuple<int, int, int> stride_tuple{1, 1, 1}; std::tuple<int, int, int> padding_tuple{0, 0, 0}; std::tuple<int, int, int> dilation_tuple{1, 1, 1}; + std::tuple<int, int, int> output_padding_tuple{0, 0, 0}; if (auto pv = std::get_if<int>(&stride); pv) { stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv}; @@ -3725,12 +3748,20 @@ void init_ops(nb::module_& m) { dilation_tuple = std::get<std::tuple<int, int, int>>(dilation); } + if (auto pv = std::get_if<int>(&output_padding); pv) { + output_padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv}; + } else { + output_padding_tuple = + std::get<std::tuple<int, int, int>>(output_padding); + } + return mx::conv_transpose3d( input, weight, stride_tuple, padding_tuple, dilation_tuple, + output_padding_tuple, groups, s); }, @@ -3739,6 +3770,7 @@ void init_ops(nb::module_& m) { "stride"_a = 1, "padding"_a = 0, "dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, nb::kw_only(), "stream"_a = nb::none(), @@ -3761,6 +3793,9 @@ void init_ops(nb::module_& m) { dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` + output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with + output padding. All spatial dimensions get the same output + padding if only one number is specified. Default: ``0``. 
groups (int, optional): input feature groups. Default: ``1``. Returns: @@ -3777,6 +3812,7 @@ void init_ops(nb::module_& m) { std::pair<std::vector<int>, std::vector<int>>>& padding, const std::variant<int, std::vector<int>>& kernel_dilation, const std::variant<int, std::vector<int>>& input_dilation, + const std::variant<int, std::vector<int>>& output_padding, int groups, bool flip, mx::StreamOrDevice s) { @@ -3785,6 +3821,7 @@ void init_ops(nb::module_& m) { std::vector<int> padding_hi_vec; std::vector<int> kernel_dilation_vec; std::vector<int> input_dilation_vec; + std::vector<int> output_padding_vec; if (auto pv = std::get_if<int>(&stride); pv) { stride_vec.push_back(*pv); @@ -3817,6 +3854,12 @@ void init_ops(nb::module_& m) { input_dilation_vec = std::get<std::vector<int>>(input_dilation); } + if (auto pv = std::get_if<int>(&output_padding); pv) { + output_padding_vec.push_back(*pv); + } else { + output_padding_vec = std::get<std::vector<int>>(output_padding); + } + return mx::conv_general( /* array input = */ std::move(input), /* array weight = */ std::move(weight), @@ -3827,6 +3870,8 @@ void init_ops(nb::module_& m) { std::move(kernel_dilation_vec), /* std::vector<int> input_dilation = */ std::move(input_dilation_vec), + /* std::vector<int> output_padding = */ + std::move(output_padding_vec), /* int groups = */ groups, /* bool flip = */ flip, s); @@ -3837,6 +3882,7 @@ void init_ops(nb::module_& m) { "padding"_a = 0, "kernel_dilation"_a = 1, "input_dilation"_a = 1, + "output_padding"_a = 0, "groups"_a = 1, "flip"_a = false, nb::kw_only(), @@ -3861,6 +3907,9 @@ void init_ops(nb::module_& m) { input_dilation (int or list(int), optional): :obj:`list` with input dilation. All spatial dimensions get the same dilation if only one number is specified. Default: ``1`` + output_padding (int or list(int), optional): :obj:`list` with + output padding. All spatial dimensions get the same padding + if only one number is specified. Default: ``0`` groups (int, optional): Input feature groups. Default: ``1``. flip (bool, optional): Flip the order in which the spatial dimensions of the weights are processed. 
Performs the cross-correlation operator when diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index de0f3352cf..7866db19ca 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3911,4 +3911,64 @@ TEST_CASE("test bitwise shift operations") { CHECK_EQ(right_shift_bool_result.dtype(), uint8); CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>()); +} + +TEST_CASE("test conv_transpose1d with output_padding") { + auto in = array({1.0, 2.0, 3.0}, {1, 3, 1}); + auto wt = array({1.0, 1.0}, {1, 2, 1}); + int stride = 2; + int padding = 0; + int output_padding = 1; + + auto out = conv_transpose1d(in, wt, stride, padding, 1, output_padding, 1); + auto expected = array({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0}, {1, 7, 1}); + CHECK(array_equal(out, expected).item<bool>()); +} + +TEST_CASE("test conv_transpose2d with output_padding") { + auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 2, 2, 1}); + auto wt = array({1.0, 1.0, 1.0, 1.0}, {1, 2, 2, 1}); + std::pair<int, int> stride{2, 2}; + std::pair<int, int> padding{0, 0}; + std::pair<int, int> output_padding{1, 1}; + + auto out = + conv_transpose2d(in, wt, stride, padding, {1, 1}, output_padding, 1); + auto expected = array( + {1.0, 1.0, + 2.0, 2.0, + 0.0, + 1.0, 1.0, + 2.0, 2.0, + 0.0, + 3.0, 3.0, + 4.0, 4.0, + 0.0, + 3.0, 3.0, + 4.0, 4.0, + 0.0, + 0.0, 0.0, + 0.0, 0.0, + 0.0}, + {1, 5, 5, 1}); + CHECK(array_equal(out, expected).item<bool>()); +} + +TEST_CASE("test conv_transpose3d with output_padding") { + auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 2, 2, 2, 1}); + auto wt = array({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1, 2, 2, 2, 1}); + std::tuple<int, int, int> stride{2, 2, 2}; + std::tuple<int, int, int> padding{0, 0, 0}; + std::tuple<int, int, int> output_padding{1, 1, 1}; + + auto out = + conv_transpose3d(in, wt, stride, padding, {1, 1, 1}, output_padding, 1); + auto expected = array( + {1.0, 1.0, 2.0, 2.0, 0.0, 1.0, 1.0, 2.0, 2.0, 0.0, 3.0, 3.0, 4.0, 4.0, 0.0, 3.0, 3.0, 4.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 1.0, 1.0, 2.0, 2.0, 0.0, 1.0, 1.0, 2.0, 2.0, 0.0, 3.0, 3.0, 4.0, 4.0, 0.0, 3.0, 3.0, 4.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 5.0, 5.0, 6.0, 6.0, 0.0, 5.0, 5.0, 6.0, 6.0, 0.0, 7.0, 7.0, 8.0, 8.0, 0.0, 7.0, 7.0, 8.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 5.0, 5.0, 6.0, 6.0, 0.0, 5.0, 5.0, 6.0, 6.0, 0.0, 7.0, 7.0, 8.0, 8.0, 0.0, 7.0, 7.0, 8.0, 8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + {1, 5, 5, 5, 1}); + CHECK(array_equal(out, expected).item<bool>()); } \ No newline at end of file