diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 720e2c546..35ed3d44e 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -944,8 +944,6 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { wt = arr_copy; } - auto padding_ = padding_lo_; - // 3D conv if (out.ndim() == 5) { conv_3D_gpu( @@ -954,7 +952,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -969,7 +967,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, @@ -985,7 +983,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { in, wt, out, - padding_, + padding_lo_, kernel_strides_, kernel_dilation_, input_dilation_, diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 671c86a32..35dcf42ac 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase): atol=2e-5 if dtype == np.float32 else 5e-4, ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_asymmetric_padding(self): + inputs = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32) + strides = (2, 2, 2) + + pt_out = torch.conv3d( + torch.permute(torch.tensor(inputs), (0, 4, 1, 2, 3)), + torch.permute(torch.tensor(kernel), (0, 4, 1, 2, 3)), + stride=strides, + padding=2, + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 4, 1))[:, 1:, 1:, 1:, :].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=strides, + padding=([0, 0, 0], [1, 1, 1]), + ) + + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + + inputs = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32) + kernel = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32) + + pt_out = torch.conv2d( + torch.permute(torch.tensor(inputs), (0, 3, 1, 2)), + torch.permute(torch.tensor(kernel), (0, 3, 1, 2)), + stride=1, + padding=(1, 0), + ) + pt_out = torch.permute(pt_out, (0, 2, 3, 1))[:, 1:].numpy() + + mx_out = mx.conv_general( + mx.array(inputs), + mx.array(kernel), + stride=1, + padding=([0, 0], [1, 0]), + ) + self.assertTrue(mx.allclose(mx_out, mx.array(pt_out), atol=1e-3, rtol=1e-3)) + if __name__ == "__main__": unittest.main()