Fix a couple of slicing bugs (#1827)

* fix a few bugs

* fix conv grad

* speedup test

* comment
Author: Awni Hannun
Date: 2025-02-05 19:50:08 -08:00
Committed by: GitHub
parent 9174606d4c
commit af1b725fda
14 changed files with 170 additions and 107 deletions
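
Note on the hunk below: the old weight-gradient path in Convolution::vjp special-cased flip_ with its own transposed convolution and padding arithmetic. The new code computes the gradient through a single conv_general call for flipped and unflipped kernels alike, then, when flip_ is set, reverses the spatial axes of grad_trans with a negative-stride slice; a standalone sketch of that reversal follows the diff.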

@@ -1229,58 +1229,45 @@ std::vector<array> Convolution::vjp(
             in, wt, cotan, kernel_strides_, padding_, stream());
         grads.push_back(grad);
       } else {
-        if (flip_) {
-          auto padding = padding_;
-          for (int i = 0; i < padding.size(); i++) {
-            int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
-            padding[i] = wt_size - padding_[i] - 1;
-          }
-          auto cotan_trans = group_transpose(cotan, -1, 0, -1);
-          auto in_trans = swapaxes(in, 0, -1, stream());
-          auto grad_trans = conv_general(
-              /* const array& input = */ cotan_trans,
-              /* const array& weight = */ in_trans,
-              /* std::vector<int> stride = */ kernel_dilation_,
-              /* std::vector<int> padding_lo = */ padding,
-              /* std::vector<int> padding_hi = */ padding,
-              /* std::vector<int> kernel_dilation = */ input_dilation_,
-              /* std::vector<int> input_dilation = */ kernel_strides_,
-              /* int groups = */ groups_,
-              /* bool flip = */ false,
-              stream());
-          if (groups_ > 1) {
-            grads.push_back(group_transpose(grad_trans, -1, 0, -2));
-          } else {
-            grads.push_back(grad_trans);
-          }
-        } else {
-          std::vector<int> padding_lo = padding_;
-          std::vector<int> padding_hi = padding_;
-          for (int i = 0; i < padding_hi.size(); ++i) {
-            int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
-            int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
-            int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
-            padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
-          }
-          auto cotan_trans = swapaxes(cotan, 0, -1, stream());
-          auto in_trans = group_transpose(in, -1, 0, -1);
-          auto grad_trans = conv_general(
-              /* const array& input = */ in_trans,
-              /* const array& weight = */ cotan_trans,
-              /* std::vector<int> stride = */ kernel_dilation_,
-              /* std::vector<int> padding_lo = */ padding_lo,
-              /* std::vector<int> padding_hi = */ padding_hi,
-              /* std::vector<int> kernel_dilation = */ kernel_strides_,
-              /* std::vector<int> input_dilation = */ input_dilation_,
-              /* int groups = */ groups_,
-              /* bool flip = */ false,
-              stream());
-          grads.push_back(swapaxes(grad_trans, 0, -1, stream()));
-        }
+        std::vector<int> padding_lo = padding_;
+        std::vector<int> padding_hi = padding_;
+        for (int i = 0; i < padding_hi.size(); ++i) {
+          int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
+          int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
+          int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
+          padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
+        }
+        auto cotan_trans = swapaxes(cotan, 0, -1, stream());
+        auto in_trans = group_transpose(in, -1, 0, -1);
+        auto grad_trans = conv_general(
+            /* const array& input = */ in_trans,
+            /* const array& weight = */ cotan_trans,
+            /* std::vector<int> stride = */ kernel_dilation_,
+            /* std::vector<int> padding_lo = */ padding_lo,
+            /* std::vector<int> padding_hi = */ padding_hi,
+            /* std::vector<int> kernel_dilation = */ kernel_strides_,
+            /* std::vector<int> input_dilation = */ input_dilation_,
+            /* int groups = */ groups_,
+            /* bool flip = */ false,
+            stream());
+        if (flip_) {
+          auto start = Shape(grad_trans.ndim(), 0);
+          auto stop = Shape(grad_trans.ndim(), 0);
+          auto strides = Shape(grad_trans.ndim(), 1);
+          for (int i = 0; i < stop.size(); ++i) {
+            if (i >= 1 && i < stop.size() - 1) {
+              start[i] = grad_trans.shape(i);
+              stop[i] = -start[i] - 1;
+              strides[i] = -1;
+            } else {
+              stop[i] = grad_trans.shape(i);
+            }
+          }
+          grad_trans = slice(grad_trans, start, stop, strides, stream());
+        }
+        grads.push_back(swapaxes(grad_trans, 0, -1, stream()));
       }
     }
   }
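
What the new flip_ branch relies on, as a standalone sketch: slicing an axis of length n with start = n, stop = -n - 1, and stride = -1 visits indices n-1, n-2, ..., 0, i.e. it reverses the axis, so the kernel flip is expressed as a strided slice rather than a separate reversal primitive. The snippet below is a minimal illustration, not part of the commit; it assumes the mlx C++ API (the slice/Shape calls mirror the diff, while arange and reshape are standard mlx ops), and reverse_spatial is a hypothetical helper name.

#include <iostream>

#include "mlx/mlx.h"

using namespace mlx::core;

// Reverse every interior (spatial) axis of an NHWC-style array, leaving the
// leading and trailing (batch/channel) axes in forward order -- the same
// bounds the new flip_ branch builds for grad_trans.
array reverse_spatial(const array& x, StreamOrDevice s = {}) {
  auto start = Shape(x.ndim(), 0);
  auto stop = Shape(x.ndim(), 0);
  auto strides = Shape(x.ndim(), 1);
  for (int i = 0; i < static_cast<int>(x.ndim()); ++i) {
    if (i >= 1 && i < static_cast<int>(x.ndim()) - 1) {
      start[i] = x.shape(i);      // clamped to the last element
      stop[i] = -x.shape(i) - 1;  // one step past index 0, walking backwards
      strides[i] = -1;
    } else {
      stop[i] = x.shape(i);       // full axis, forward order
    }
  }
  return slice(x, start, stop, strides, s);
}

int main() {
  auto x = reshape(arange(4), {1, 4, 1});       // [0, 1, 2, 3] along axis 1
  std::cout << reverse_spatial(x) << std::endl; // [3, 2, 1, 0] along axis 1
  return 0;
}

The padding_hi recomputation serves the same unification: it pads the swapped-operand convolution so its output spatial extent always equals the kernel's. For a 1-D case with in_size = 5, out_size = 3, wt_size = 3, unit strides and dilations, and padding_ = 0, padding_hi becomes 3 - 5 + 3 - 0 - 1 = 0, and the gradient convolution produces (5 + 0 + 0 - 3) / 1 + 1 = 3 taps, exactly wt_size.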