fix conv grad (#2187)

Awni Hannun 2025-05-15 19:20:36 -07:00 committed by GitHub
parent a2cadb8218
commit 602f43e3d1
2 changed files with 33 additions and 7 deletions

View File

@@ -1116,13 +1116,11 @@ array conv_weight_backward_patches(
   // Pad input
   std::vector<int> padded_axes(in.ndim() - 2, 0);
   std::iota(padded_axes.begin(), padded_axes.end(), 1);
-  Shape padding_lo_(padding_lo.begin(), padding_lo.end());
-  Shape padding_hi_(padding_hi.begin(), padding_hi.end());
   auto in_padded =
       pad(in,
           padded_axes,
-          padding_lo_,
-          padding_hi_,
+          Shape(padding_lo),
+          Shape(padding_hi),
           array(0, in.dtype()),
           "constant",
           s);
@@ -1274,8 +1272,14 @@ std::vector<array> Convolution::vjp(
           in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream());
       grads.push_back(grad);
     } else {
-      std::vector<int> padding_lo = padding_lo_;
-      std::vector<int> padding_hi = padding_hi_;
+      auto padding_hi = padding_lo_;
+
+      for (int i = 0; i < padding_hi.size(); ++i) {
+        int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1);
+        int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1);
+        int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1);
+        padding_hi[i] = out_size - in_size + wt_size - padding_hi[i] - 1;
+      }
       auto cotan_trans = swapaxes(cotan, 0, -1, stream());
       auto in_trans = group_transpose(in, -1, 0, -1);
@@ -1284,7 +1288,7 @@ std::vector<array> Convolution::vjp(
           /* const array& input = */ in_trans,
           /* const array& weight = */ cotan_trans,
           /* std::vector<int> stride = */ kernel_dilation_,
-          /* std::vector<int> padding_lo = */ padding_lo,
+          /* std::vector<int> padding_lo = */ padding_lo_,
           /* std::vector<int> padding_hi = */ padding_hi,
           /* std::vector<int> kernel_dilation = */ kernel_strides_,
           /* std::vector<int> input_dilation = */ input_dilation_,
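
For reference, here is a rough worked example of the shape arithmetic behind the new padding_hi loop (not part of the commit), using the first case from the new test: with the adjusted high padding, the conv_general call that computes the weight gradient produces an output whose spatial extent equals the kernel's, which the unadjusted padding got wrong for strided convolutions. The numbers below are assumptions matching that test case (zero padding, unit dilations).

# Sketch of the padding_hi arithmetic for input (3, 5, 4), kernel (6, 2, 2),
# stride 2, zero padding, unit dilations (first case of the new test).
in_spatial, wt_spatial, stride = 5, 2, 2
padding_lo = 0
input_dilation = kernel_dilation = 1

# Forward output (cotangent) spatial size.
cotan_spatial = (in_spatial - (1 + kernel_dilation * (wt_spatial - 1))) // stride + 1  # 2

# Quantities computed in the patched loop. padding_hi starts as a copy of
# padding_lo_ in the patched code, so padding_lo appears on the right-hand side.
in_size = 1 + input_dilation * (in_spatial - 1)   # 5
out_size = 1 + stride * (cotan_spatial - 1)       # 3
wt_size = 1 + kernel_dilation * (wt_spatial - 1)  # 2
padding_hi = out_size - in_size + wt_size - padding_lo - 1  # -1

# Weight-gradient conv: the input is `in` (dilated by input_dilation), the
# "kernel" is the cotangent (dilated by the forward stride), and the stride is
# the forward kernel_dilation. Its output spatial size should match the weight's.
grad_spatial = (in_size + padding_lo + padding_hi - out_size) // kernel_dilation + 1
assert grad_spatial == wt_spatial  # 2; the pre-fix padding_hi (0 here) would give 3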

View File

@@ -1130,6 +1130,28 @@ class TestConv(mlx_tests.MLXTestCase):
         )
         self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3))
 
+    def test_basic_grad_shapes(self):
+        def loss_fn(kernel, inputs, strides, groups):
+            return mx.sum(
+                mx.conv_general(
+                    inputs,
+                    kernel,
+                    stride=strides,
+                    groups=groups,
+                )
+            )
+
+        for in_shape, k_shape, strides, groups in [
+            ((3, 5, 4), (6, 2, 2), (2,), 2),
+            ((3, 5, 4), (24, 2, 1), (2,), 4),
+            ((3, 5, 5, 4), (6, 2, 2, 2), (2, 1), 2),
+            ((3, 5, 5, 4), (24, 2, 2, 1), (2, 2), 4),
+        ]:
+            grads = mx.grad(loss_fn)(
+                mx.zeros(k_shape), mx.zeros(in_shape), strides, groups
+            )
+            self.assertEqual(grads.shape, k_shape)
+
 
 if __name__ == "__main__":
     unittest.main()
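
The new test only checks gradient shapes. A quick way to also probe the gradient values is a finite-difference check along a random direction; the snippet below is a hedged sketch along those lines (not part of the commit), assuming mlx.core is importable and reusing the first test case's shapes. Since the loss is linear in the kernel, the central difference should match the analytic directional derivative up to float32 rounding.

import mlx.core as mx

def loss_fn(kernel, inputs):
    return mx.sum(mx.conv_general(inputs, kernel, stride=(2,), groups=2))

mx.random.seed(0)
inputs = mx.random.normal((3, 5, 4))
kernel = mx.random.normal((6, 2, 2))

# Analytic gradient with respect to the kernel (first argument of loss_fn).
grad = mx.grad(loss_fn)(kernel, inputs)

# Central finite difference along a random direction v.
v = mx.random.normal(kernel.shape)
eps = 1e-2
fd = (loss_fn(kernel + eps * v, inputs) - loss_fn(kernel - eps * v, inputs)) / (2 * eps)
assert mx.allclose(fd, mx.sum(grad * v), rtol=1e-2, atol=1e-2)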