mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-15 17:39:05 +08:00
Added conv_transpose fixes
This commit is contained in:
@@ -3609,6 +3609,7 @@ void init_ops(nb::module_& m) {
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"output_padding"_a = 0,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -3623,6 +3624,7 @@ void init_ops(nb::module_& m) {
|
||||
stride (int, optional): Kernel stride. Default: ``1``.
|
||||
padding (int, optional): Input padding. Default: ``0``.
|
||||
dilation (int, optional): Kernel dilation. Default: ``1``.
|
||||
output_padding (int, optional): Output padding. Default: ``0``.
|
||||
groups (int, optional): Input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
@@ -3635,11 +3637,13 @@ void init_ops(nb::module_& m) {
|
||||
const std::variant<int, std::pair<int, int>>& stride,
|
||||
const std::variant<int, std::pair<int, int>>& padding,
|
||||
const std::variant<int, std::pair<int, int>>& dilation,
|
||||
const std::variant<int, std::pair<int, int>>& output_padding,
|
||||
int groups,
|
||||
mx::StreamOrDevice s) {
|
||||
std::pair<int, int> stride_pair{1, 1};
|
||||
std::pair<int, int> padding_pair{0, 0};
|
||||
std::pair<int, int> dilation_pair{1, 1};
|
||||
std::pair<int, int> output_padding_pair{0, 0};
|
||||
|
||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||
stride_pair = std::pair<int, int>{*pv, *pv};
|
||||
@@ -3659,14 +3663,28 @@ void init_ops(nb::module_& m) {
|
||||
dilation_pair = std::get<std::pair<int, int>>(dilation);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
||||
output_padding_pair = std::pair<int, int>{*pv, *pv};
|
||||
} else {
|
||||
output_padding_pair = std::get<std::pair<int, int>>(output_padding);
|
||||
}
|
||||
|
||||
return mx::conv_transpose2d(
|
||||
input, weight, stride_pair, padding_pair, dilation_pair, groups, s);
|
||||
input,
|
||||
weight,
|
||||
stride_pair,
|
||||
padding_pair,
|
||||
dilation_pair,
|
||||
output_padding_pair,
|
||||
groups,
|
||||
s);
|
||||
},
|
||||
nb::arg(),
|
||||
nb::arg(),
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"output_padding"_a = 0,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -3689,6 +3707,9 @@ void init_ops(nb::module_& m) {
|
||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
kernel dilation. All spatial dimensions get the same dilation
|
||||
if only one number is specified. Default: ``1``
|
||||
output_padding (int or tuple(int), optional): :obj:`tuple` of size 2 with
|
||||
output padding. All spatial dimensions get the same output
|
||||
padding if only one number is specified. Default: ``0``.
|
||||
groups (int, optional): input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
@@ -3701,11 +3722,13 @@ void init_ops(nb::module_& m) {
|
||||
const std::variant<int, std::tuple<int, int, int>>& stride,
|
||||
const std::variant<int, std::tuple<int, int, int>>& padding,
|
||||
const std::variant<int, std::tuple<int, int, int>>& dilation,
|
||||
const std::variant<int, std::tuple<int, int, int>>& output_padding,
|
||||
int groups,
|
||||
mx::StreamOrDevice s) {
|
||||
std::tuple<int, int, int> stride_tuple{1, 1, 1};
|
||||
std::tuple<int, int, int> padding_tuple{0, 0, 0};
|
||||
std::tuple<int, int, int> dilation_tuple{1, 1, 1};
|
||||
std::tuple<int, int, int> output_padding_tuple{0, 0, 0};
|
||||
|
||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||
stride_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
@@ -3725,12 +3748,20 @@ void init_ops(nb::module_& m) {
|
||||
dilation_tuple = std::get<std::tuple<int, int, int>>(dilation);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
||||
output_padding_tuple = std::tuple<int, int, int>{*pv, *pv, *pv};
|
||||
} else {
|
||||
output_padding_tuple =
|
||||
std::get<std::tuple<int, int, int>>(output_padding);
|
||||
}
|
||||
|
||||
return mx::conv_transpose3d(
|
||||
input,
|
||||
weight,
|
||||
stride_tuple,
|
||||
padding_tuple,
|
||||
dilation_tuple,
|
||||
output_padding_tuple,
|
||||
groups,
|
||||
s);
|
||||
},
|
||||
@@ -3739,6 +3770,7 @@ void init_ops(nb::module_& m) {
|
||||
"stride"_a = 1,
|
||||
"padding"_a = 0,
|
||||
"dilation"_a = 1,
|
||||
"output_padding"_a = 0,
|
||||
"groups"_a = 1,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
@@ -3761,6 +3793,9 @@ void init_ops(nb::module_& m) {
|
||||
dilation (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
kernel dilation. All spatial dimensions get the same dilation
|
||||
if only one number is specified. Default: ``1``
|
||||
output_padding (int or tuple(int), optional): :obj:`tuple` of size 3 with
|
||||
output padding. All spatial dimensions get the same output
|
||||
padding if only one number is specified. Default: ``0``.
|
||||
groups (int, optional): input feature groups. Default: ``1``.
|
||||
|
||||
Returns:
|
||||
@@ -3777,6 +3812,7 @@ void init_ops(nb::module_& m) {
|
||||
std::pair<std::vector<int>, std::vector<int>>>& padding,
|
||||
const std::variant<int, std::vector<int>>& kernel_dilation,
|
||||
const std::variant<int, std::vector<int>>& input_dilation,
|
||||
const std::variant<int, std::vector<int>>& output_padding,
|
||||
int groups,
|
||||
bool flip,
|
||||
mx::StreamOrDevice s) {
|
||||
@@ -3785,6 +3821,7 @@ void init_ops(nb::module_& m) {
|
||||
std::vector<int> padding_hi_vec;
|
||||
std::vector<int> kernel_dilation_vec;
|
||||
std::vector<int> input_dilation_vec;
|
||||
std::vector<int> output_padding_vec;
|
||||
|
||||
if (auto pv = std::get_if<int>(&stride); pv) {
|
||||
stride_vec.push_back(*pv);
|
||||
@@ -3817,6 +3854,12 @@ void init_ops(nb::module_& m) {
|
||||
input_dilation_vec = std::get<std::vector<int>>(input_dilation);
|
||||
}
|
||||
|
||||
if (auto pv = std::get_if<int>(&output_padding); pv) {
|
||||
output_padding_vec.push_back(*pv);
|
||||
} else {
|
||||
output_padding_vec = std::get<std::vector<int>>(output_padding);
|
||||
}
|
||||
|
||||
return mx::conv_general(
|
||||
/* array input = */ std::move(input),
|
||||
/* array weight = */ std::move(weight),
|
||||
@@ -3827,6 +3870,8 @@ void init_ops(nb::module_& m) {
|
||||
std::move(kernel_dilation_vec),
|
||||
/* std::vector<int> input_dilation = */
|
||||
std::move(input_dilation_vec),
|
||||
/* std::vector<int> output_padding = */
|
||||
std::move(output_padding_vec),
|
||||
/* int groups = */ groups,
|
||||
/* bool flip = */ flip,
|
||||
s);
|
||||
@@ -3837,6 +3882,7 @@ void init_ops(nb::module_& m) {
|
||||
"padding"_a = 0,
|
||||
"kernel_dilation"_a = 1,
|
||||
"input_dilation"_a = 1,
|
||||
"output_padding"_a = 0,
|
||||
"groups"_a = 1,
|
||||
"flip"_a = false,
|
||||
nb::kw_only(),
|
||||
@@ -3861,6 +3907,9 @@ void init_ops(nb::module_& m) {
|
||||
input_dilation (int or list(int), optional): :obj:`list` with
|
||||
input dilation. All spatial dimensions get the same dilation
|
||||
if only one number is specified. Default: ``1``
|
||||
output_padding (int or list(int), optional): :obj:`list` with
|
||||
output padding. All spatial dimensions get the same padding
|
||||
if only one number is specified. Default: ``0``
|
||||
groups (int, optional): Input feature groups. Default: ``1``.
|
||||
flip (bool, optional): Flip the order in which the spatial dimensions of
|
||||
the weights are processed. Performs the cross-correlation operator when
|
||||
|
||||
Reference in New Issue
Block a user