fix conv grad (#2187)

Awni Hannun authored 2025-05-15 19:20:36 -07:00, committed by GitHub
parent a2cadb8218
commit 602f43e3d1
2 changed files with 33 additions and 7 deletions

@@ -1116,13 +1116,11 @@ array conv_weight_backward_patches(
   // Pad input
   std::vector<int> padded_axes(in.ndim() - 2, 0);
   std::iota(padded_axes.begin(), padded_axes.end(), 1);
-  Shape padding_lo_(padding_lo.begin(), padding_lo.end());
-  Shape padding_hi_(padding_hi.begin(), padding_hi.end());
   auto in_padded =
       pad(in,
           padded_axes,
-          padding_lo_,
-          padding_hi_,
+          Shape(padding_lo),
+          Shape(padding_hi),
           array(0, in.dtype()),
           "constant",
           s);
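
The first hunk drops the two iterator-pair temporaries (whose trailing-underscore names also collided with the `padding_lo_`/`padding_hi_` member naming used in `Convolution::vjp` below) and constructs the `Shape` arguments directly. A minimal sketch of the equivalence, assuming `Shape` is an alias for `std::vector<int32_t>` as in mlx/array.h, so that direct construction from a `std::vector<int>` is a plain copy on platforms where `int` is 32 bits:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>; // stand-in for mlx's Shape alias

int main() {
  std::vector<int> padding_lo = {0, 1, 2};
  Shape old_form(padding_lo.begin(), padding_lo.end()); // removed form
  Shape new_form(padding_lo); // new form; a copy when int == int32_t
  assert(old_form == new_form);
  return 0;
}
```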
@@ -1274,8 +1272,14 @@ std::vector<array> Convolution::vjp(
           in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream());
       grads.push_back(grad);
     } else {
-      std::vector<int> padding_lo = padding_lo_;
-      std::vector<int> padding_hi = padding_hi_;
+      auto padding_hi = padding_lo_;
+
+      for (int i = 0; i < padding_hi.size(); ++i) {
+        int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
+        int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
+        int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
+        padding_hi[i] = out_size - in_size + wt_size - padding_hi[i] - 1;
+      }
 
       auto cotan_trans = swapaxes(cotan, 0, -1, stream());
       auto in_trans = group_transpose(in, -1, 0, -1);
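
The fix itself: in the weight-gradient branch, the gradient is computed as a convolution that takes the channel-transposed input as its input and the channel-transposed cotangent as its weight, with the roles of stride and kernel dilation swapped (see the call in the next hunk). The output of that convolution must have exactly the kernel's spatial extent, so only the trailing padding needs adjusting. The loop solves the standard convolution size formula

    out = (in_extent + pad_lo + pad_hi - kernel_extent) / stride + 1

for pad_hi, with in_extent = in_size (the input-dilated input), kernel_extent = out_size (the stride-dilated cotangent), stride = kernel_dilation_[i], and out = wt.shape(1 + i), giving pad_hi = out_size - in_size + wt_size - pad_lo - 1. Note that the `padding_hi[i]` on the right-hand side of the assignment still holds the `padding_lo_` value it was copied from. A standalone check of this arithmetic, using hypothetical 1D sizes (not mlx code):

```cpp
#include <cassert>

int main() {
  // Hypothetical forward configuration: length-13 input, width-4 kernel.
  int in_spatial = 13, wt_spatial = 4;
  int stride = 2, kernel_dilation = 3, input_dilation = 1;
  int pad_lo = 2, fwd_pad_hi = 1; // asymmetric padding

  // Dilated extents and forward output size (floor-division formula).
  int in_size = 1 + input_dilation * (in_spatial - 1);
  int wt_size = 1 + kernel_dilation * (wt_spatial - 1);
  int cotan_spatial = (in_size + pad_lo + fwd_pad_hi - wt_size) / stride + 1;

  // The fix: trailing padding for the weight-gradient convolution.
  int out_size = 1 + stride * (cotan_spatial - 1);
  int pad_hi = out_size - in_size + wt_size - pad_lo - 1;

  // Gradient convolution: input extent in_size, kernel extent out_size,
  // stride kernel_dilation; its output must match the kernel's extent.
  int grad_out =
      (in_size + pad_lo + pad_hi - out_size) / kernel_dilation + 1;
  assert(grad_out == wt_spatial);
  return 0;
}
```

Since pad_hi makes the numerator equal kernel_dilation * (wt_spatial - 1), the division is exact and grad_out == wt_spatial holds for any sizes, not just these.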
@@ -1284,7 +1288,7 @@ std::vector<array> Convolution::vjp(
           /* const array& input = */ in_trans,
           /* const array& weight = */ cotan_trans,
           /* std::vector<int> stride = */ kernel_dilation_,
-          /* std::vector<int> padding_lo = */ padding_lo,
+          /* std::vector<int> padding_lo = */ padding_lo_,
           /* std::vector<int> padding_hi = */ padding_hi,
           /* std::vector<int> kernel_dilation = */ kernel_strides_,
           /* std::vector<int> input_dilation = */ input_dilation_,
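
The third hunk follows from the second: with only the trailing padding recomputed, the leading padding is the member `padding_lo_` passed through unchanged, and `padding_hi` is the vector produced by the loop above. For context, a hedged user-level sketch of the case this fixes, the weight gradient of an asymmetrically padded convolution (padding_lo != padding_hi). Shapes and values are illustrative; it assumes mlx's public C++ conv_general overload matches the parameter order visible in this hunk, and that vjp from mlx/transforms.h takes (function, primals, cotangents):

```cpp
// Hypothetical repro sketch, not part of the commit.
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array in = ones({1, 13, 2}); // (N, L, C_in), channels-last
  array wt = ones({3, 4, 2});  // (C_out, K, C_in)

  auto fun = [](const std::vector<array>& inputs) {
    return std::vector<array>{conv_general(
        inputs[0],
        inputs[1],
        /* stride = */ {2},
        /* padding_lo = */ {2},
        /* padding_hi = */ {1}, // asymmetric: lo != hi
        /* kernel_dilation = */ {3},
        /* input_dilation = */ {1})};
  };

  auto out = fun({in, wt})[0];
  auto [outs, grads] = vjp(fun, {in, wt}, {ones_like(out)});
  eval(grads); // grads[1] is the weight gradient this commit fixes
  return 0;
}
```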