mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-12 04:06:39 +08:00
Added conv_transpose fixes
This commit is contained in:
parent
5f04c0f818
commit
59b934dcf9
13
mlx/ops.cpp
13
mlx/ops.cpp
@ -3769,6 +3769,7 @@ array conv_transpose_general(
|
|||||||
std::vector<int> stride,
|
std::vector<int> stride,
|
||||||
std::vector<int> padding,
|
std::vector<int> padding,
|
||||||
std::vector<int> dilation,
|
std::vector<int> dilation,
|
||||||
|
std::vector<int> output_padding,
|
||||||
int groups,
|
int groups,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
std::vector<int> padding_lo(padding.size());
|
std::vector<int> padding_lo(padding.size());
|
||||||
@ -3782,7 +3783,8 @@ array conv_transpose_general(
|
|||||||
|
|
||||||
int in_size = 1 + (conv_output_shape - 1);
|
int in_size = 1 + (conv_output_shape - 1);
|
||||||
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
|
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
|
||||||
padding_hi[i] = in_size - out_size + padding[i];
|
padding_hi[i] = in_size - out_size + padding[i] +
|
||||||
|
output_padding[i]; // Adjust with output_padding
|
||||||
}
|
}
|
||||||
|
|
||||||
return conv_general(
|
return conv_general(
|
||||||
@ -3805,10 +3807,11 @@ array conv_transpose1d(
|
|||||||
int stride /* = 1 */,
|
int stride /* = 1 */,
|
||||||
int padding /* = 0 */,
|
int padding /* = 0 */,
|
||||||
int dilation /* = 1 */,
|
int dilation /* = 1 */,
|
||||||
|
int output_padding /* = 0 */,
|
||||||
int groups /* = 1 */,
|
int groups /* = 1 */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return conv_transpose_general(
|
return conv_transpose_general(
|
||||||
in_, wt_, {stride}, {padding}, {dilation}, groups, s);
|
in_, wt_, {stride}, {padding}, {dilation}, {output_padding}, groups, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 2D transposed convolution with a filter */
|
/** 2D transposed convolution with a filter */
|
||||||
@ -3818,6 +3821,7 @@ array conv_transpose2d(
|
|||||||
const std::pair<int, int>& stride /* = {1, 1} */,
|
const std::pair<int, int>& stride /* = {1, 1} */,
|
||||||
const std::pair<int, int>& padding /* = {0, 0} */,
|
const std::pair<int, int>& padding /* = {0, 0} */,
|
||||||
const std::pair<int, int>& dilation /* = {1, 1} */,
|
const std::pair<int, int>& dilation /* = {1, 1} */,
|
||||||
|
const std::pair<int, int>& output_padding /* = {0, 0} */,
|
||||||
int groups /* = 1 */,
|
int groups /* = 1 */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return conv_transpose_general(
|
return conv_transpose_general(
|
||||||
@ -3826,6 +3830,7 @@ array conv_transpose2d(
|
|||||||
{stride.first, stride.second},
|
{stride.first, stride.second},
|
||||||
{padding.first, padding.second},
|
{padding.first, padding.second},
|
||||||
{dilation.first, dilation.second},
|
{dilation.first, dilation.second},
|
||||||
|
{output_padding.first, output_padding.second},
|
||||||
groups,
|
groups,
|
||||||
s);
|
s);
|
||||||
}
|
}
|
||||||
@ -3837,6 +3842,7 @@ array conv_transpose3d(
|
|||||||
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
|
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
|
||||||
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
|
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
|
||||||
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
|
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
|
||||||
|
const std::tuple<int, int, int>& output_padding /* = {0, 0, 0} */,
|
||||||
int groups /* = 1 */,
|
int groups /* = 1 */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return conv_transpose_general(
|
return conv_transpose_general(
|
||||||
@ -3845,6 +3851,9 @@ array conv_transpose3d(
|
|||||||
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
|
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
|
||||||
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
|
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
|
||||||
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
|
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
|
||||||
|
{std::get<0>(output_padding),
|
||||||
|
std::get<1>(output_padding),
|
||||||
|
std::get<2>(output_padding)},
|
||||||
groups,
|
groups,
|
||||||
s);
|
s);
|
||||||
}
|
}
|
||||||
|
@ -3609,6 +3609,7 @@ void init_ops(nb::module_& m) {
|
|||||||
"stride"_a = 1,
|
"stride"_a = 1,
|
||||||
"padding"_a = 0,
|
"padding"_a = 0,
|
||||||
"dilation"_a = 1,
|
"dilation"_a = 1,
|
||||||
|
"output_padding"_a = 0,
|
||||||
"groups"_a = 1,
|
"groups"_a = 1,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -3623,6 +3624,7 @@ void init_ops(nb::module_& m) {
|
|||||||
stride (int, optional): Kernel stride. Default: ``1``.
|
stride (int, optional): Kernel stride. Default: ``1``.
|
||||||
padding (int, optional): Input padding. Default: ``0``.
|
padding (int, optional): Input padding. Default: ``0``.
|
||||||
dilation (int, optional): Kernel dilation. Default: ``1``.
|
dilation (int, optional): Kernel dilation. Default: ``1``.
|
||||||
|
output_padding (int, optional): Output padding. Default: ``0``.
|
||||||
groups (int, optional): Input feature groups. Default: ``1``.
|
groups (int, optional): Input feature groups. Default: ``1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -3635,11 +3637,13 @@ void init_ops(nb::module_& m) {
|
|||||||
const std::variant<int, std::pair<int, int>>& stride,
|
const std::variant<int, std::pair<int, int>>& stride,
|
||||||
const std::variant<int, std::pair<int, int>>& padding,
|
const std::variant<int, std::pair<int, int>>& padding,
|
||||||
const std::variant<int, std::pair<int, int>>& dilation,
|
const std::variant<int, std::pair<int, int>>& dilation,
|
||||||
|
const std::variant<int, std::pair<int, int>>& output_padding,
|
||||||
int groups,
|
int groups,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
std::pair<int, int> stride_pair{1, 1};
|
std::pair<int, int> stride_pair{1, 1};
|
||||||
std::pair<int, int> padding_pair{0, 0};
|
std::pair<int, int> padding_pair{0, 0};
|
||||||
std::pair<int, int> dilation_pair{1, 1};
|
std::pair<int, int> dilation_pair{1, 1};
|
||||||
|
std::pair<int, int> output_padding_pair{0, 0};
|
||||||
|
|
||||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||||
stride_pair = std::pair<int, int>{*pv, *pv};
|
stride_pair = std::pair<int, int>{*pv, *pv};
|
||||||
@ -3659,14 +3663,28 @@ void init_ops(nb::module_& m) {
|
|||||||
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
||||||
|
output_padding_pair = std::pair<int, int>{*pv, *pv};
|
||||||
|
} else {
|
||||||
|
output_padding_pair = std::get<std::pair<int, int>>(output_padding);
|
||||||
|
}
|
||||||
|
|
||||||
return mx::conv_transpose2d(
|
return mx::conv_transpose2d(
|
||||||
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
|
input,
|
||||||
|
weight,
|
||||||
|
stride_pair,
|
||||||
|
padding_pair,
|
||||||
|
dilation_pair,
|
||||||
|
output_padding_pair,
|
||||||
|
groups,
|
||||||
|
s);
|
||||||
},
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
"stride"_a = 1,
|
"stride"_a = 1,
|
||||||
"padding"_a = 0,
|
"padding"_a = 0,
|
||||||
"dilation"_a = 1,
|
"dilation"_a = 1,
|
||||||
|
"output_padding"_a = 0,
|
||||||
"groups"_a = 1,
|
"groups"_a = 1,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -3689,6 +3707,9 @@ void init_ops(nb::module_& m) {
|
|||||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||||
kernel dilation. All spatial dimensions get the same dilation
|
kernel dilation. All spatial dimensions get the same dilation
|
||||||
if only one number is specified. Default: ``1``
|
if only one number is specified. Default: ``1``
|
||||||
|
output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||||
|
output padding. All spatial dimensions get the same output
|
||||||
|
padding if only one number is specified. Default: ``0``.
|
||||||
groups (int, optional): input feature groups. Default: ``1``.
|
groups (int, optional): input feature groups. Default: ``1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -3701,11 +3722,13 @@ void init_ops(nb::module_& m) {
|
|||||||
const std::variant<int, std::tuple<int, int, int>>& stride,
|
const std::variant<int, std::tuple<int, int, int>>& stride,
|
||||||
const std::variant<int, std::tuple<int, int, int>>& padding,
|
const std::variant<int, std::tuple<int, int, int>>& padding,
|
||||||
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
||||||
|
const std::variant<int, std::tuple<int, int, int>>& output_padding,
|
||||||
int groups,
|
int groups,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
||||||
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
||||||
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
||||||
|
std::tuple<int, int, int> output_padding_tuple{0, 0, 0};
|
||||||
|
|
||||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||||
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||||
@ -3725,12 +3748,20 @@ void init_ops(nb::module_& m) {
|
|||||||
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
||||||
|
output_padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||||
|
} else {
|
||||||
|
output_padding_tuple =
|
||||||
|
std::get<std::tuple<int, int, int>>(output_padding);
|
||||||
|
}
|
||||||
|
|
||||||
return mx::conv_transpose3d(
|
return mx::conv_transpose3d(
|
||||||
input,
|
input,
|
||||||
weight,
|
weight,
|
||||||
stride_tuple,
|
stride_tuple,
|
||||||
padding_tuple,
|
padding_tuple,
|
||||||
dilation_tuple,
|
dilation_tuple,
|
||||||
|
output_padding_tuple,
|
||||||
groups,
|
groups,
|
||||||
s);
|
s);
|
||||||
},
|
},
|
||||||
@ -3739,6 +3770,7 @@ void init_ops(nb::module_& m) {
|
|||||||
"stride"_a = 1,
|
"stride"_a = 1,
|
||||||
"padding"_a = 0,
|
"padding"_a = 0,
|
||||||
"dilation"_a = 1,
|
"dilation"_a = 1,
|
||||||
|
"output_padding"_a = 0,
|
||||||
"groups"_a = 1,
|
"groups"_a = 1,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
@ -3761,6 +3793,9 @@ void init_ops(nb::module_& m) {
|
|||||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||||
kernel dilation. All spatial dimensions get the same dilation
|
kernel dilation. All spatial dimensions get the same dilation
|
||||||
if only one number is specified. Default: ``1``
|
if only one number is specified. Default: ``1``
|
||||||
|
output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||||
|
output padding. All spatial dimensions get the same output
|
||||||
|
padding if only one number is specified. Default: ``0``.
|
||||||
groups (int, optional): input feature groups. Default: ``1``.
|
groups (int, optional): input feature groups. Default: ``1``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -3777,6 +3812,7 @@ void init_ops(nb::module_& m) {
|
|||||||
std::pair<std::vector<int>, std::vector<int>>>& padding,
|
std::pair<std::vector<int>, std::vector<int>>>& padding,
|
||||||
const std::variant<int, std::vector<int>>& kernel_dilation,
|
const std::variant<int, std::vector<int>>& kernel_dilation,
|
||||||
const std::variant<int, std::vector<int>>& input_dilation,
|
const std::variant<int, std::vector<int>>& input_dilation,
|
||||||
|
const std::variant<int, std::vector<int>>& output_padding,
|
||||||
int groups,
|
int groups,
|
||||||
bool flip,
|
bool flip,
|
||||||
mx::StreamOrDevice s) {
|
mx::StreamOrDevice s) {
|
||||||
@ -3785,6 +3821,7 @@ void init_ops(nb::module_& m) {
|
|||||||
std::vector<int> padding_hi_vec;
|
std::vector<int> padding_hi_vec;
|
||||||
std::vector<int> kernel_dilation_vec;
|
std::vector<int> kernel_dilation_vec;
|
||||||
std::vector<int> input_dilation_vec;
|
std::vector<int> input_dilation_vec;
|
||||||
|
std::vector<int> output_padding_vec;
|
||||||
|
|
||||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||||
stride_vec.push_back(*pv);
|
stride_vec.push_back(*pv);
|
||||||
@ -3817,6 +3854,12 @@ void init_ops(nb::module_& m) {
|
|||||||
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
|
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
||||||
|
output_padding_vec.push_back(*pv);
|
||||||
|
} else {
|
||||||
|
output_padding_vec = std::get<std::vector<int>>(output_padding);
|
||||||
|
}
|
||||||
|
|
||||||
return mx::conv_general(
|
return mx::conv_general(
|
||||||
/* array input = */ std::move(input),
|
/* array input = */ std::move(input),
|
||||||
/* array weight = */ std::move(weight),
|
/* array weight = */ std::move(weight),
|
||||||
@ -3827,6 +3870,8 @@ void init_ops(nb::module_& m) {
|
|||||||
std::move(kernel_dilation_vec),
|
std::move(kernel_dilation_vec),
|
||||||
/* std::vector<int> input_dilation = */
|
/* std::vector<int> input_dilation = */
|
||||||
std::move(input_dilation_vec),
|
std::move(input_dilation_vec),
|
||||||
|
/* std::vector<int> output_padding = */
|
||||||
|
std::move(output_padding_vec),
|
||||||
/* int groups = */ groups,
|
/* int groups = */ groups,
|
||||||
/* bool flip = */ flip,
|
/* bool flip = */ flip,
|
||||||
s);
|
s);
|
||||||
@ -3837,6 +3882,7 @@ void init_ops(nb::module_& m) {
|
|||||||
"padding"_a = 0,
|
"padding"_a = 0,
|
||||||
"kernel_dilation"_a = 1,
|
"kernel_dilation"_a = 1,
|
||||||
"input_dilation"_a = 1,
|
"input_dilation"_a = 1,
|
||||||
|
"output_padding"_a = 0,
|
||||||
"groups"_a = 1,
|
"groups"_a = 1,
|
||||||
"flip"_a = false,
|
"flip"_a = false,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -3861,6 +3907,9 @@ void init_ops(nb::module_& m) {
|
|||||||
input_dilation (int or list(int), optional): :obj:`list` with
|
input_dilation (int or list(int), optional): :obj:`list` with
|
||||||
input dilation. All spatial dimensions get the same dilation
|
input dilation. All spatial dimensions get the same dilation
|
||||||
if only one number is specified. Default: ``1``
|
if only one number is specified. Default: ``1``
|
||||||
|
output_padding (int or list(int), optional): :obj:`list` with
|
||||||
|
output padding. All spatial dimensions get the same padding
|
||||||
|
if only one number is specified. Default: ``0``
|
||||||
groups (int, optional): Input feature groups. Default: ``1``.
|
groups (int, optional): Input feature groups. Default: ``1``.
|
||||||
flip (bool, optional): Flip the order in which the spatial dimensions of
|
flip (bool, optional): Flip the order in which the spatial dimensions of
|
||||||
the weights are processed. Performs the cross-correlation operator when
|
the weights are processed. Performs the cross-correlation operator when
|
||||||
|
@ -3911,4 +3911,64 @@ TEST_CASE("test bitwise shift operations") {
|
|||||||
|
|
||||||
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
|
CHECK_EQ(right_shift_bool_result.dtype(), uint8);
|
||||||
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
|
CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test conv_transpose1d with output_padding") {
|
||||||
|
auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
|
||||||
|
auto wt = array({1.0, 1.0}, {1, 1, 2});
|
||||||
|
int stride = 2;
|
||||||
|
int padding = 0;
|
||||||
|
int output_padding = 1;
|
||||||
|
|
||||||
|
auto out = conv_transpose1d(in, wt, stride, padding, 1, output_padding, 1);
|
||||||
|
auto expected = array({1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0}, {1, 1, 7});
|
||||||
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test conv_transpose2d with output_padding") {
|
||||||
|
auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
|
||||||
|
auto wt = array({1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2});
|
||||||
|
std::pair<int, int> stride{2, 2};
|
||||||
|
std::pair<int, int> padding{0, 0};
|
||||||
|
std::pair<int, int> output_padding{1, 1};
|
||||||
|
|
||||||
|
auto out =
|
||||||
|
conv_transpose2d(in, wt, stride, padding, {1, 1}, output_padding, 1);
|
||||||
|
auto expected = array(
|
||||||
|
{1.0,
|
||||||
|
1.0,
|
||||||
|
2.0,
|
||||||
|
2.0,
|
||||||
|
0.0,
|
||||||
|
3.0,
|
||||||
|
3.0,
|
||||||
|
4.0,
|
||||||
|
4.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0},
|
||||||
|
{1, 1, 5, 5});
|
||||||
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test conv_transpose3d with output_padding") {
|
||||||
|
auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
|
||||||
|
auto wt = array({1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1, 1, 2, 2, 2});
|
||||||
|
std::tuple<int, int, int> stride{2, 2, 2};
|
||||||
|
std::tuple<int, int, int> padding{0, 0, 0};
|
||||||
|
std::tuple<int, int, int> output_padding{1, 1, 1};
|
||||||
|
|
||||||
|
auto out =
|
||||||
|
conv_transpose3d(in, wt, stride, padding, {1, 1, 1}, output_padding, 1);
|
||||||
|
auto expected = array(
|
||||||
|
{1.0, 1.0, 2.0, 2.0, 0.0, 3.0, 3.0, 4.0, 4.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 5.0, 5.0, 6.0, 6.0, 0.0, 7.0, 7.0, 8.0, 8.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
|
||||||
|
{1, 1, 5, 5, 5});
|
||||||
|
CHECK(array_equal(out, expected).item<bool>());
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user