This commit is contained in:
Awni Hannun 2025-05-09 09:54:20 -07:00
parent 76def90b73
commit 88cc8e0755
2 changed files with 45 additions and 5 deletions

View File

@ -944,8 +944,6 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
wt = arr_copy; wt = arr_copy;
} }
auto padding_ = padding_lo_;
// 3D conv // 3D conv
if (out.ndim() == 5) { if (out.ndim() == 5) {
conv_3D_gpu( conv_3D_gpu(
@ -954,7 +952,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@ -969,7 +967,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,
@ -985,7 +983,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
in, in,
wt, wt,
out, out,
padding_, padding_lo_,
kernel_strides_, kernel_strides_,
kernel_dilation_, kernel_dilation_,
input_dilation_, input_dilation_,

View File

@ -1088,6 +1088,48 @@ class TestConv(mlx_tests.MLXTestCase):
atol=2e-5 if dtype == np.float32 else 5e-4, atol=2e-5 if dtype == np.float32 else 5e-4,
) )
@unittest.skipIf(not has_torch, "requires Torch")
def test_asymmetric_padding(self):
    """Compare ``mx.conv_general`` with unequal low/high padding to Torch.

    The Torch reference is computed with symmetric padding and the leading
    output entries along the padded axes are then sliced away, which leaves
    the outputs that a convolution with zero low padding would produce.
    The same configuration is requested from ``mx.conv_general`` directly
    via a ``(low, high)`` padding pair, and both results are compared with
    ``mx.allclose``. A 3D and a 2D case are exercised.
    """
    # ---- 3D case: 8x8x8 spatial input, 3x3x3 kernel, stride 2 ----
    x_np = np.random.normal(size=(2, 8, 8, 8, 3)).astype(np.float32)
    w_np = np.random.normal(size=(2, 3, 3, 3, 3)).astype(np.float32)
    stride_3d = (2, 2, 2)

    # Move the last axis to position 1 before handing the data to Torch,
    # and move it back to the end afterwards.
    last_to_second_3d = (0, 4, 1, 2, 3)
    second_to_last_3d = (0, 2, 3, 4, 1)

    x_pt = torch.permute(torch.tensor(x_np), last_to_second_3d)
    w_pt = torch.permute(torch.tensor(w_np), last_to_second_3d)
    ref = torch.conv3d(x_pt, w_pt, stride=stride_3d, padding=2)
    ref = torch.permute(ref, second_to_last_3d)
    # Drop the first output along each spatial axis.
    expected = ref[:, 1:, 1:, 1:, :].numpy()

    actual = mx.conv_general(
        mx.array(x_np),
        mx.array(w_np),
        stride=stride_3d,
        padding=([0, 0, 0], [1, 1, 1]),
    )
    matches = mx.allclose(actual, mx.array(expected), atol=1e-3, rtol=1e-3)
    self.assertTrue(matches)

    # ---- 2D case: 10x10 spatial input, 2x2 kernel, stride 1 ----
    x_np = np.random.normal(size=(2, 10, 10, 3)).astype(np.float32)
    w_np = np.random.normal(size=(2, 2, 2, 3)).astype(np.float32)

    last_to_second_2d = (0, 3, 1, 2)
    second_to_last_2d = (0, 2, 3, 1)

    x_pt = torch.permute(torch.tensor(x_np), last_to_second_2d)
    w_pt = torch.permute(torch.tensor(w_np), last_to_second_2d)
    ref = torch.conv2d(x_pt, w_pt, stride=1, padding=(1, 0))
    ref = torch.permute(ref, second_to_last_2d)
    # Only the first spatial axis is padded here, so only it is trimmed.
    expected = ref[:, 1:].numpy()

    actual = mx.conv_general(
        mx.array(x_np),
        mx.array(w_np),
        stride=1,
        padding=([0, 0], [1, 0]),
    )
    matches = mx.allclose(actual, mx.array(expected), atol=1e-3, rtol=1e-3)
    self.assertTrue(matches)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()