diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 889d68aa00..03ca06bdd3 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1240,12 +1240,6 @@ std::vector Convolution::vjp( std::vector padding_lo = padding_lo_; std::vector padding_hi = padding_hi_; - for (int i = 0; i < padding_hi.size(); ++i) { - int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); - int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); - int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); - padding_hi[i] = out_size - in_size + wt_size - padding_hi_[i] - 1; - } auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto in_trans = group_transpose(in, -1, 0, -1); @@ -1291,7 +1285,8 @@ std::pair, std::vector> Convolution::vmap( in, w, kernel_strides_, - padding_, + padding_lo_, + padding_hi_, kernel_dilation_, input_dilation_, groups,