This commit is contained in:
aturker1 2025-04-29 00:40:13 +03:00 committed by Awni Hannun
parent 7942191a64
commit 76def90b73

View File

@ -1240,12 +1240,6 @@ std::vector<array> Convolution::vjp(
std::vector<int> padding_lo = padding_lo_;
std::vector<int> padding_hi = padding_hi_;
for (int i = 0; i < padding_hi.size(); ++i) {
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
padding_hi[i] = out_size - in_size + wt_size - padding_hi_[i] - 1;
}
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
auto in_trans = group_transpose(in, -1, 0, -1);
@ -1291,7 +1285,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
in,
w,
kernel_strides_,
padding_,
padding_lo_,
padding_hi_,
kernel_dilation_,
input_dilation_,
groups,