From 602f43e3d1f75a1036a3008024afa8f27c3140d7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 15 May 2025 19:20:36 -0700 Subject: [PATCH] fix conv grad (#2187) --- mlx/primitives.cpp | 18 +++++++++++------- python/tests/test_conv.py | 22 ++++++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 87b2bc924..c2bb59c05 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1116,13 +1116,11 @@ array conv_weight_backward_patches( // Pad input std::vector padded_axes(in.ndim() - 2, 0); std::iota(padded_axes.begin(), padded_axes.end(), 1); - Shape padding_lo_(padding_lo.begin(), padding_lo.end()); - Shape padding_hi_(padding_hi.begin(), padding_hi.end()); auto in_padded = pad(in, padded_axes, - padding_lo_, - padding_hi_, + Shape(padding_lo), + Shape(padding_hi), array(0, in.dtype()), "constant", s); @@ -1274,8 +1272,14 @@ std::vector Convolution::vjp( in, wt, cotan, kernel_strides_, padding_lo_, padding_hi_, stream()); grads.push_back(grad); } else { - std::vector padding_lo = padding_lo_; - std::vector padding_hi = padding_hi_; + auto padding_hi = padding_lo_; + + for (int i = 0; i < padding_hi.size(); ++i) { + int in_size = 1 + input_dilation_[i] * (in.shape(1 + i) - 1); + int out_size = 1 + kernel_strides_[i] * (cotan.shape(1 + i) - 1); + int wt_size = 1 + kernel_dilation_[i] * (wt.shape(1 + i) - 1); + padding_hi[i] = out_size - in_size + wt_size - padding_hi[i] - 1; + } auto cotan_trans = swapaxes(cotan, 0, -1, stream()); auto in_trans = group_transpose(in, -1, 0, -1); @@ -1284,7 +1288,7 @@ std::vector Convolution::vjp( /* const array& input = */ in_trans, /* const array& weight = */ cotan_trans, /* std::vector stride = */ kernel_dilation_, - /* std::vector padding_lo = */ padding_lo, + /* std::vector padding_lo = */ padding_lo_, /* std::vector padding_hi = */ padding_hi, /* std::vector kernel_dilation = */ kernel_strides_, /* std::vector input_dilation = */ input_dilation_, diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 35dcf42ac..7d63e4751 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1130,6 +1130,28 @@ class TestConv(mlx_tests.MLXTestCase): ) self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + def test_basic_grad_shapes(self): + def loss_fn(kernel, inputs, strides, groups): + return mx.sum( + mx.conv_general( + inputs, + kernel, + stride=strides, + groups=groups, + ) + ) + + for in_shape, k_shape, strides, groups in [ + ((3, 5, 4), (6, 2, 2), (2,), 2), + ((3, 5, 4), (24, 2, 1), (2,), 4), + ((3, 5, 5, 4), (6, 2, 2, 2), (2, 1), 2), + ((3, 5, 5, 4), (24, 2, 2, 1), (2, 2), 4), + ]: + grads = mx.grad(loss_fn)( + mx.zeros(k_shape), mx.zeros(in_shape), strides, groups + ) + self.assertEqual(grads.shape, k_shape) + if __name__ == "__main__": unittest.main()