Transposed Convolution (#1245)

* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Max-Heinrich Laves
2024-09-07 04:52:38 +02:00
committed by GitHub
parent ba3e913c7a
commit efeb9c0f02
15 changed files with 1337 additions and 45 deletions

View File

@@ -3298,6 +3298,93 @@ array conv3d(
s);
}
// Helper function for transposed convolutions
array conv_transpose_general(
const array& input,
const array& weight,
std::vector<int> stride,
std::vector<int> padding,
std::vector<int> dilation,
int groups,
StreamOrDevice s) {
std::vector<int> padding_lo(padding.size());
std::vector<int> padding_hi(padding.size());
for (int i = 0; i < padding.size(); ++i) {
int wt_size = 1 + dilation[i] * (weight.shape(1 + i) - 1);
padding_lo[i] = wt_size - padding[i] - 1;
int conv_output_shape = (input.shape(i + 1) - 1) * stride[i] -
2 * padding[i] + dilation[i] * (weight.shape(i + 1) - 1) + 1;
int in_size = 1 + (conv_output_shape - 1);
int out_size = 1 + stride[i] * (input.shape(1 + i) - 1);
padding_hi[i] = in_size - out_size + padding[i];
}
return conv_general(
/* const array& input = */ input,
/* const array& weight = */ weight,
/* std::vector<int> stride = */ std::vector(stride.size(), 1),
/* std::vector<int> padding_lo = */ std::move(padding_lo),
/* std::vector<int> padding_hi = */ std::move(padding_hi),
/* std::vector<int> kernel_dilation = */ std::move(dilation),
/* std::vector<int> input_dilation = */ std::move(stride),
/* int groups = */ groups,
/* bool flip = */ true,
s);
}
/** 1D transposed convolution with a filter */
array conv_transpose1d(
const array& in_,
const array& wt_,
int stride /* = 1 */,
int padding /* = 0 */,
int dilation /* = 1 */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_transpose_general(
in_, wt_, {stride}, {padding}, {dilation}, groups, s);
}
/** 2D transposed convolution with a filter */
array conv_transpose2d(
const array& in_,
const array& wt_,
const std::pair<int, int>& stride /* = {1, 1} */,
const std::pair<int, int>& padding /* = {0, 0} */,
const std::pair<int, int>& dilation /* = {1, 1} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_transpose_general(
in_,
wt_,
{stride.first, stride.second},
{padding.first, padding.second},
{dilation.first, dilation.second},
groups,
s);
}
/** 3D transposed convolution with a filter */
array conv_transpose3d(
const array& in_,
const array& wt_,
const std::tuple<int, int, int>& stride /* = {1, 1, 1} */,
const std::tuple<int, int, int>& padding /* = {0, 0, 0} */,
const std::tuple<int, int, int>& dilation /* = {1, 1, 1} */,
int groups /* = 1 */,
StreamOrDevice s /* = {} */) {
return conv_transpose_general(
in_,
wt_,
{std::get<0>(stride), std::get<1>(stride), std::get<2>(stride)},
{std::get<0>(padding), std::get<1>(padding), std::get<2>(padding)},
{std::get<0>(dilation), std::get<1>(dilation), std::get<2>(dilation)},
groups,
s);
}
/** General convolution with a filter */
array conv_general(
array in,