Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Transposed Convolution (#1245)
* initial implementation for conv_transpose
  - ran pre-commit
  - implemented conv_transpose
  - updated conv_general docstring
  - updated code comments
  - removed commented run_conv_checks
  - updated acknowledgments
  - added missing entry to ops.rst
  - added op to nn.layers
  - resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Committed by: GitHub
Parent: ba3e913c7a
Commit: efeb9c0f02
mlx/ops.cpp (+87)
@@ -3298,6 +3298,93 @@ array conv3d(
      s);
}

// Helper function for transposed convolutions
array conv_transpose_general(
    const array& input,
    const array& weight,
    std::vector<int> stride,
    std::vector<int> padding,
    std::vector<int> dilation,
    int groups,
    StreamOrDevice s) {
  std::vector<int> padding_lo(padding.size());
  std::vector<int> padding_hi(padding.size());
  for (int i = 0; i < padding.size(); ++i) {
    int wt_size = 1 + dilation[i] * (weight.shape(1 + i) - 1);
    padding_lo[i] = wt_size - padding[i] - 1;

    int conv_output_shape = (input.shape(i + 1) - 1) * stride[i] -
        2 * padding[i] + dilation[i] * (weight.shape(i + 1) - 1) + 1;

    int in_size = 1 + (conv_output_shape - 1);
    int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
    padding_hi[i] = in_size - out_size + padding[i];
  }

  return conv_general(
      /* const array& input = */ input,
      /* const array& weight = */ weight,
      /* std::vector<int> stride = */ std::vector(stride.size(), 1),
      /* std::vector<int> padding_lo = */ std::move(padding_lo),
      /* std::vector<int> padding_hi = */ std::move(padding_hi),
      /* std::vector<int> kernel_dilation = */ std::move(dilation),
      /* std::vector<int> input_dilation = */ std::move(stride),
      /* int groups = */ groups,
      /* bool flip = */ true,
      s);
}

/** 1D transposed convolution with a filter */
array conv_transpose1d(
    const array& in_,
    const array& wt_,
    int stride /* = 1 */,
    int padding /* = 0 */,
    int dilation /* = 1 */,
    int groups /* = 1 */,
    StreamOrDevice s /* = {} */) {
  return conv_transpose_general(
      in_, wt_, {stride}, {padding}, {dilation}, groups, s);
}

/** 2D transposed convolution with a filter */
array conv_transpose2d(
    const array& in_,
    const array& wt_,
    const std::pair<int, int>& stride /* = {1, 1} */,
    const std::pair<int, int>& padding /* = {0, 0} */,
    const std::pair<int, int>& dilation /* = {1, 1} */,
    int groups /* = 1 */,
    StreamOrDevice s /* = {} */) {
  return conv_transpose_general(
      in_,
      wt_,
      {stride.first, stride.second},
      {padding.first, padding.second},
      {dilation.first, dilation.second},
      groups,
      s);
}

/** 3D transposed convolution with a filter */
array conv_transpose3d(
    const array& in_,
    const array& wt_,
    const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
    const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
    const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
    int groups /* = 1 */,
    StreamOrDevice s /* = {} */) {
  return conv_transpose_general(
      in_,
      wt_,
      {std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
      {std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
      {std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
      groups,
      s);
}

/** General convolution with a filter */
array conv_general(
    array in,
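Note on the padding arithmetic in conv_transpose_general: for each spatial axis the two expressions reduce to padding_lo[i] = padding_hi[i] = dilation[i] * (K - 1) - padding[i], with K = weight.shape(1 + i). Dilating the input by stride (the input_dilation argument), flipping the kernel, and convolving at unit stride with that symmetric padding then produces the standard transposed-convolution output length (L - 1) * stride - 2 * padding + dilation * (K - 1) + 1. The sketch below exercises the new 1D op through the C++ API; the shapes, the mlx/mlx.h include, and the channels-last (N, L, C) layout are illustrative assumptions, not part of the diff.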
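// Usage sketch (illustrative, not part of the commit). Assumes the mlx
// umbrella header and channels-last shapes: input (N, L, C_in),
// weight (C_out, K, C_in).
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array x = ones({1, 4, 8});   // batch 1, length 4, 8 input channels
  array w = ones({16, 3, 8});  // 16 output channels, kernel size 3

  // With stride 2, padding 0, dilation 1, groups 1 the output length is
  // (4 - 1) * 2 - 0 + 1 * (3 - 1) + 1 = 9.
  array y = conv_transpose1d(
      x, w, /* stride */ 2, /* padding */ 0, /* dilation */ 1, /* groups */ 1);
  eval(y);  // y has shape (1, 9, 16)
  return 0;
}

Expressing the op through conv_general with input_dilation = stride and flip = true is what makes a separate ConvolutionTranspose primitive unnecessary, matching the commit message note about removing it.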