mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-31 07:12:20 +08:00
Remove unused code in Convolution::vjp (#2408)
This commit is contained in:
parent
28d068bce6
commit
588854195f
@ -1271,19 +1271,6 @@ std::vector<array> Convolution::vjp(
|
|||||||
has_neg_padding |= (pd < 0);
|
has_neg_padding |= (pd < 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto padding_lo_ = std::vector<int>(padding_lo);
|
|
||||||
auto padding_hi_ = std::vector<int>(padding_hi);
|
|
||||||
|
|
||||||
// Use negative padding on the gradient output
|
|
||||||
if (has_neg_padding) {
|
|
||||||
for (auto& p : padding_lo_) {
|
|
||||||
p = std::max(0, p);
|
|
||||||
}
|
|
||||||
for (auto& p : padding_hi_) {
|
|
||||||
p = std::max(0, p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto wt_trans = group_transpose(wt, 0, 1, -1);
|
auto wt_trans = group_transpose(wt, 0, 1, -1);
|
||||||
auto grad = conv_general(
|
auto grad = conv_general(
|
||||||
/* const array& input = */ cotan,
|
/* const array& input = */ cotan,
|
||||||
@ -1305,12 +1292,9 @@ std::vector<array> Convolution::vjp(
|
|||||||
for (int i = 0; i < grad.ndim() - 2; i++) {
|
for (int i = 0; i < grad.ndim() - 2; i++) {
|
||||||
if (padding_lo[i] < 0) {
|
if (padding_lo[i] < 0) {
|
||||||
starts[i + 1] -= padding_lo[i];
|
starts[i + 1] -= padding_lo[i];
|
||||||
padding_lo[i] = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (padding_hi[i] < 0) {
|
if (padding_hi[i] < 0) {
|
||||||
stops[i + 1] += padding_hi[i];
|
stops[i + 1] += padding_hi[i];
|
||||||
padding_hi[i] = 0;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user