mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-30 14:41:23 +08:00
fix bugs
This commit is contained in:
parent
7942191a64
commit
76def90b73
@ -1240,12 +1240,6 @@ std::vector<array> Convolution::vjp(
|
||||
std::vector<int> padding_lo = padding_lo_;
|
||||
std::vector<int> padding_hi = padding_hi_;
|
||||
|
||||
for (int i = 0; i < padding_hi.size(); ++i) {
|
||||
int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
|
||||
int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
|
||||
int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
|
||||
padding_hi[i] = out_size - in_size + wt_size - padding_hi_[i] - 1;
|
||||
}
|
||||
auto cotan_trans = swapaxes(cotan, 0, -1, stream());
|
||||
auto in_trans = group_transpose(in, -1, 0, -1);
|
||||
|
||||
@ -1291,7 +1285,8 @@ std::pair<std::vector<array>, std::vector<int>> Convolution::vmap(
|
||||
in,
|
||||
w,
|
||||
kernel_strides_,
|
||||
padding_,
|
||||
padding_lo_,
|
||||
padding_hi_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups,
|
||||
|
Loading…
Reference in New Issue
Block a user