Transposed Convolution (#1245)

* initial implementation for conv_transpose

ran pre-commit

implemented conv_transpose

updated conv_general docstring

updated code comments

removed commented run_conv_checks

updated acknowledgments

added missing entry to ops.rst

added op to nn.layers

resolved merge conflicts

* removed ConvolutionTranspose primitive as suggested by reviewer

* remove transpose flag, add another test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
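
As context for the diff below: per the commit message, this PR adds conv_transpose ops (with entries in ops.rst and nn.layers). A minimal usage sketch, not from the commit, assuming the C++ signature of the new op mirrors conv1d's (stride, padding, dilation, groups) and MLX's channels-last layouts:

// Hypothetical example; conv_transpose1d's exact signature is an assumption.
#include <iostream>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = ones({1, 4, 2});  // batch 1, length 4, 2 input channels (NLC)
  auto w = ones({3, 3, 2});  // 3 output channels, kernel size 3, 2 input channels
  auto y = conv_transpose1d(x, w, /* stride */ 2, /* padding */ 0,
                            /* dilation */ 1, /* groups */ 1);
  // Expected output length: (L - 1) * stride - 2 * padding + dilation * (K - 1) + 1
  //                        = (4 - 1) * 2 - 0 + 1 * (3 - 1) + 1 = 9
  std::cout << y.shape(1) << std::endl;  // 9
  return 0;
}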
Author: Max-Heinrich Laves
Date: 2024-09-07 04:52:38 +02:00
Committed by: GitHub
Parent: ba3e913c7a
Commit: efeb9c0f02
15 changed files with 1337 additions and 45 deletions


@@ -952,8 +952,8 @@ std::vector<array> Convolution::vjp(
         /* const array& input = */ cotan,
         /* const array& weight = */ wt_trans,
         /* std::vector<int> stride = */ input_dilation_,
-        /* std::vector<int> padding_lo = */ padding_lo_,
-        /* std::vector<int> padding_hi = */ padding_hi_,
+        /* std::vector<int> padding_lo = */ padding_lo,
+        /* std::vector<int> padding_hi = */ padding_hi,
         /* std::vector<int> kernel_dilation = */ kernel_dilation_,
         /* std::vector<int> input_dilation = */ kernel_strides_,
         /* int groups = */ 1,
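
The hunk above switches the input-gradient convolution from the stored padding_lo_/padding_hi_ members to locally computed padding_lo/padding_hi (set up earlier in the function, outside the displayed lines). The identity behind this path: for a stride-1 forward conv with kernel size K and padding p, the input gradient is a conv of the cotangent with the flipped kernel under padding K - 1 - p, the same adjustment that appears as wt_size - padding_[i] - 1 in the hunk below. A standalone arithmetic check of the stride-1 case (plain C++, not MLX code):

#include <cassert>

int main() {
  int L = 7, K = 3, p = 1;                       // stride-1 forward conv
  int L_out = L + 2 * p - K + 1;                 // forward output length = 7
  int L_grad = L_out + 2 * (K - 1 - p) - K + 1;  // cotan conv, flipped kernel
  assert(L_grad == L);  // the input gradient matches the input length
  return 0;
}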
@@ -990,36 +990,61 @@ std::vector<array> Convolution::vjp(
       no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);
     }
-    if (no_dilation) {
+    if (no_dilation && !flip_) {
       auto grad = conv_weight_backward_patches(
           in, wt, cotan, kernel_strides_, padding_, stream());
       grads.push_back(grad);
     } else {
-      std::vector<int> padding_lo = padding_;
-      std::vector<int> padding_hi = padding_;
-      for (int i = 0; i < padding_hi.size(); ++i) {
-        int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
-        int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
-        int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
-        padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
-      }
-      auto in_trans = swapaxes(in, 0, -1, stream());
       auto cotan_trans = swapaxes(cotan, 0, -1, stream());
-      auto grad_trans = conv_general(
-          /* const array& input = */ in_trans,
-          /* const array& weight = */ cotan_trans,
-          /* std::vector<int> stride = */ kernel_dilation_,
-          /* std::vector<int> padding_lo = */ padding_lo,
-          /* std::vector<int> padding_hi = */ padding_hi,
-          /* std::vector<int> kernel_dilation = */ kernel_strides_,
-          /* std::vector<int> input_dilation = */ input_dilation_,
-          /* int groups = */ 1,
-          /* bool flip = */ flip_,
-          stream());
-      auto grad = swapaxes(grad_trans, 0, -1, stream());
-      grads.push_back(grad);
+      auto in_trans = swapaxes(in, 0, -1, stream());
+      if (flip_) {
+        auto padding = padding_;
+        for (int i = 0; i < padding.size(); i++) {
+          int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
+          padding[i] = wt_size - padding_[i] - 1;
+        }
+        auto grad_trans = conv_general(
+            /* const array& input = */ cotan_trans,
+            /* const array& weight = */ in_trans,
+            /* std::vector<int> stride = */ kernel_dilation_,
+            /* std::vector<int> padding_lo = */ padding,
+            /* std::vector<int> padding_hi = */ padding,
+            /* std::vector<int> kernel_dilation = */ input_dilation_,
+            /* std::vector<int> input_dilation = */ kernel_strides_,
+            /* int groups = */ 1,
+            /* bool flip = */ false,
+            stream());
+        auto grad = swapaxes(grad_trans, 0, -1, stream());
+        grads.push_back(grad_trans);
+      } else {
+        std::vector<int> padding_lo = padding_;
+        std::vector<int> padding_hi = padding_;
+        for (int i = 0; i < padding_hi.size(); ++i) {
+          int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
+          int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
+          int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
+          padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
+        }
+        auto in_trans = swapaxes(in, 0, -1, stream());
+        auto cotan_trans = swapaxes(cotan, 0, -1, stream());
+        auto grad_trans = conv_general(
+            /* const array& input = */ in_trans,
+            /* const array& weight = */ cotan_trans,
+            /* std::vector<int> stride = */ kernel_dilation_,
+            /* std::vector<int> padding_lo = */ padding_lo,
+            /* std::vector<int> padding_hi = */ padding_hi,
+            /* std::vector<int> kernel_dilation = */ kernel_strides_,
+            /* std::vector<int> input_dilation = */ input_dilation_,
+            /* int groups = */ 1,
+            /* bool flip = */ false,
+            stream());
+        auto grad = swapaxes(grad_trans, 0, -1, stream());
+        grads.push_back(grad);
+      }
     }
   }
 }
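
Two changes in the hunk above: the conv_weight_backward_patches fast path is now taken only when the kernel is not flipped, and the general weight-gradient path splits on flip_. With a flipped kernel, the weight gradient is computed as a convolution of the channel-transposed cotangent with the input under padding wt_size - padding - 1; without flipping, the original input-by-cotangent convolution is kept, with padding_hi chosen so that the conv output spatially matches the kernel. That choice can be checked with plain arithmetic; the hyperparameters below are arbitrary samples, and this is a standalone sketch, not MLX code:

#include <cassert>

int main() {
  int L = 11, K = 4;          // input length, kernel size
  int stride = 3, pad = 2;    // forward stride and (symmetric) padding
  int in_dil = 1, k_dil = 2;  // input dilation, kernel dilation

  int in_size = 1 + in_dil * (L - 1);  // dilated input extent
  int wt_size = 1 + k_dil * (K - 1);   // dilated kernel extent
  int L_out = (in_size + 2 * pad - wt_size) / stride + 1;  // forward output

  // padding_hi exactly as in the hunk's non-flipped branch
  int out_size = 1 + stride * (L_out - 1);
  int pad_hi = out_size - in_size + wt_size - pad - 1;

  // Weight-gradient conv: the input dilated by in_dil, the cotangent acting as
  // the kernel dilated by the forward stride, and conv stride equal to k_dil.
  int wgrad_out = (in_size + pad + pad_hi - out_size) / k_dil + 1;
  assert(wgrad_out == K);  // the gradient has the weight's spatial extent
  return 0;
}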