Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
Conv grad with groups + bugfix (#1449)
* fix bug in flipped conv with groups, start of grad for groups
* fix
* fix
* fix + test
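For context on what the change enables: before this commit the backward pass threw for groups > 1, so a grouped convolution could not be differentiated. Below is a minimal standalone sketch (not taken from the commit or its test) of exercising the new path through the public C++ API; it assumes the usual mlx signatures for ones, conv1d, vjp, and eval, and the shapes and groups = 4 setting are purely illustrative.

#include <functional>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Grouped 1D convolution: input is NLC, weight is (C_out, K, C_in / groups).
  // Here C_in = 8, C_out = 8, groups = 4, so each filter sees 2 input channels.
  // Shapes are illustrative, not from the commit.
  auto x = ones({1, 16, 8});
  auto w = ones({8, 3, 2});

  std::function<std::vector<array>(const std::vector<array>&)> fun =
      [](const std::vector<array>& inputs) {
        return std::vector<array>{conv1d(
            inputs[0],
            inputs[1],
            /* stride = */ 1,
            /* padding = */ 0,
            /* dilation = */ 1,
            /* groups = */ 4)};
      };

  // Cotangent matches the conv output shape (1, 14, 8).
  auto cotan = ones({1, 14, 8});

  // Before this commit, Convolution::vjp threw
  // "[Convolution] Backward pass not implemented for groups > 1.";
  // with it, gradients for both the input and the weight come back.
  auto [outs, grads] = vjp(fun, std::vector<array>{x, w}, std::vector<array>{cotan});
  eval(grads);
  return 0;
}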
@@ -929,16 +929,28 @@ std::vector<array> Convolution::vjp(
   assert(primals.size() == 2);
   std::vector<array> grads;

-  if (groups_ != 1) {
-    throw std::invalid_argument(
-        "[Convolution] Backward pass not implemented for groups > 1.");
-  }
-
   // Collect info
   auto& in = primals[0];
   auto& wt = primals[1];
   auto& cotan = cotangents[0];

+  auto group_transpose =
+      [this](const array& x, int group_dim, int ax_a, int ax_b) {
+        if (groups_ > 1) {
+          auto shape = x.shape();
+          if (group_dim < 0) {
+            group_dim += shape.size();
+          }
+          shape.insert(shape.begin() + group_dim, groups_);
+          shape[group_dim + 1] = shape[group_dim + 1] / groups_;
+          auto x_trans = swapaxes(
+              reshape(x, std::move(shape), stream()), ax_a, ax_b, stream());
+          return flatten(x_trans, group_dim, group_dim + 1, stream());
+        } else {
+          return swapaxes(x, 0, -1, stream());
+        }
+      };
+
   for (int a : argnums) {
     // Grads for input
     if (a == 0) {
@@ -976,8 +988,7 @@ std::vector<array> Convolution::vjp(
         }
       }

-      auto wt_trans = swapaxes(wt, 0, -1, stream());
-
+      auto wt_trans = group_transpose(wt, 0, 1, -1);
       auto grad = conv_general(
           /* const array& input = */ cotan,
           /* const array& weight = */ wt_trans,
@@ -986,7 +997,7 @@ std::vector<array> Convolution::vjp(
           /* std::vector<int> padding_hi = */ padding_hi,
           /* std::vector<int> kernel_dilation = */ kernel_dilation_,
           /* std::vector<int> input_dilation = */ kernel_strides_,
-          /* int groups = */ 1,
+          /* int groups = */ groups_,
           /* bool flip = */ !flip_,
           stream());

@@ -1020,14 +1031,11 @@ std::vector<array> Convolution::vjp(
         no_dilation &= (input_dilation_[i] == 1) && (kernel_dilation_[i] == 1);
       }

-      if (no_dilation && !flip_) {
+      if (no_dilation && !flip_ && groups_ == 1) {
         auto grad = conv_weight_backward_patches(
             in, wt, cotan, kernel_strides_, padding_, stream());
         grads.push_back(grad);
       } else {
-        auto cotan_trans = swapaxes(cotan, 0, -1, stream());
-        auto in_trans = swapaxes(in, 0, -1, stream());
-
         if (flip_) {
           auto padding = padding_;
           for (int i = 0; i < padding.size(); i++) {
@@ -1035,6 +1043,9 @@ std::vector<array> Convolution::vjp(
             padding[i] = wt_size - padding_[i] - 1;
           }

+          auto cotan_trans = group_transpose(cotan, -1, 0, -1);
+          auto in_trans = swapaxes(in, 0, -1, stream());
+
           auto grad_trans = conv_general(
               /* const array& input = */ cotan_trans,
               /* const array& weight = */ in_trans,
@@ -1043,11 +1054,14 @@ std::vector<array> Convolution::vjp(
               /* std::vector<int> padding_hi = */ padding,
               /* std::vector<int> kernel_dilation = */ input_dilation_,
               /* std::vector<int> input_dilation = */ kernel_strides_,
-              /* int groups = */ 1,
+              /* int groups = */ groups_,
               /* bool flip = */ false,
               stream());
-          auto grad = swapaxes(grad_trans, 0, -1, stream());
-          grads.push_back(grad_trans);
+          if (groups_ > 1) {
+            grads.push_back(group_transpose(grad_trans, -1, 0, -2));
+          } else {
+            grads.push_back(grad_trans);
+          }
         } else {
           std::vector<int> padding_lo = padding_;
           std::vector<int> padding_hi = padding_;
@@ -1058,9 +1072,9 @@ std::vector<array> Convolution::vjp(
             int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
             padding_hi[i] = out_size - in_size + wt_size - padding_[i] - 1;
           }

-          auto in_trans = swapaxes(in, 0, -1, stream());
           auto cotan_trans = swapaxes(cotan, 0, -1, stream());
+          auto in_trans = group_transpose(in, -1, 0, -1);

           auto grad_trans = conv_general(
               /* const array& input = */ in_trans,
               /* const array& weight = */ cotan_trans,
@@ -1069,11 +1083,10 @@ std::vector<array> Convolution::vjp(
               /* std::vector<int> padding_hi = */ padding_hi,
               /* std::vector<int> kernel_dilation = */ kernel_strides_,
               /* std::vector<int> input_dilation = */ input_dilation_,
-              /* int groups = */ 1,
+              /* int groups = */ groups_,
               /* bool flip = */ false,
               stream());
-          auto grad = swapaxes(grad_trans, 0, -1, stream());
-          grads.push_back(grad);
+          grads.push_back(swapaxes(grad_trans, 0, -1, stream()));
         }
       }
     }
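To make the new group_transpose helper easier to follow, here is a hedged shape walkthrough of the input-gradient call group_transpose(wt, 0, 1, -1), written as a standalone sketch with the same reshape / swapaxes / flatten ops that appear in the diff. The concrete sizes (groups = 4, C_out = 12, K = 5, C_in = 8) are hypothetical, chosen only to keep each axis distinguishable; the default stream is used here, whereas the lambda in the diff threads stream() through explicitly.

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  int groups = 4;
  // A grouped conv1d weight: (C_out, K, C_in / groups) = (12, 5, 2).
  auto wt = zeros({12, 5, 2});

  // group_transpose(wt, /* group_dim = */ 0, /* ax_a = */ 1, /* ax_b = */ -1):
  // 1) split the grouped axis:           (12, 5, 2) -> (4, 3, 5, 2)
  auto split = reshape(wt, {groups, 3, 5, 2});
  // 2) swap per-group out channels with per-group in channels:
  //                                      (4, 3, 5, 2) -> (4, 2, 5, 3)
  auto swapped = swapaxes(split, 1, -1);
  // 3) merge the group axis back:        (4, 2, 5, 3) -> (8, 5, 3)
  auto wt_trans = flatten(swapped, 0, 1);

  // wt_trans now has shape (C_in, K, C_out / groups): a weight whose output
  // channels are the original input channels, which is what the transposed
  // convolution in the input-grad path consumes.
  wt_trans.eval();
  return 0;
}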