	Convolution update (#651)
* Init steel conv and update Conv primitive
* Update slow CPU implementation to support flipping and input dilation
* Winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
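The capability this commit tests: mx.conv_general supports strides, padding, kernel dilation, input dilation, groups, and kernel flipping (true convolution). A minimal sketch of a call, with illustrative shapes in MLX's channels-last layout; the parameter names come straight from the tests below:

    import mlx.core as mx

    x = mx.random.normal((2, 32, 32, 16))  # input, NHWC
    w = mx.random.normal((32, 5, 5, 16))   # weights, (O, kH, kW, C)

    # The same parameter set the new tests sweep over.
    y = mx.conv_general(
        x,
        w,
        stride=(2, 2),
        padding=(2, 2),
        kernel_dilation=(2, 3),
        input_dilation=(1, 1),
        groups=1,
        flip=False,
    )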
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
 import unittest
@@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):
                 _, outs_mx = mx.vjp(
                     f,
-                    [
-                        in_mx,
-                        wt_mx,
-                    ],
-                    [
-                        ct_mx,
-                    ],
+                    [in_mx, wt_mx],
+                    [ct_mx],
                 )
                 pt_grad_in = F.grad.conv1d_input(
                     in_pt.shape,
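(Aside: the hunk above is formatting only. mx.vjp takes a function, a list of primals, and a list of cotangents, and returns the function outputs together with the vector-Jacobian products, one per primal. A minimal sketch of the call pattern, with a toy f standing in for the conv closure used in the test:)

    import mlx.core as mx

    def f(x, w):
        return x * w  # stand-in for the conv closure in the test

    x, w = mx.array([1.0, 2.0]), mx.array([3.0, 4.0])
    ct = mx.array([1.0, 1.0])  # cotangent matching f's output

    outs, vjps = mx.vjp(f, [x, w], [ct])
    # vjps[0] pairs with x, vjps[1] with w -- as in the test's
    # `_, outs_mx = mx.vjp(f, [in_mx, wt_mx], [ct_mx])`.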
@@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
                 self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
 
         for dtype in ("float32",):
-            for N, C, O in (
-                (1, 1, 1),
-                (1, 6, 1),
-                (1, 1, 6),
-                (4, 32, 64),
-            ):
-                for idim, kdim, stride, padding in (
-                    ((1, 1), (1, 1), (1, 1), (0, 0)),
-                    ((3, 3), (3, 1), (1, 1), (0, 0)),
-                    ((31, 31), (5, 5), (5, 5), (2, 2)),
-                ):
-                    run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
+                for idim, kdim, stride, padding, dilation in (
+                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
+                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
+                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
+                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
+                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
+                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
+                ):
+                    run_conv2D_grad(
+                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
+                    )
+
+    def __conv_general_test(
+        self,
+        in_shape,
+        wt_shape,
+        stride=1,
+        padding=0,
+        kernel_dilation=1,
+        input_dilation=1,
+        groups=1,
+        flip=False,
+        np_dtype=np.float32,
+        atol=1e-5,
+    ):
+
+        with self.subTest(
+            in_shape=in_shape,
+            wt_shape=wt_shape,
+            stride=stride,
+            padding=padding,
+            kernel_dilation=kernel_dilation,
+            input_dilation=input_dilation,
+            groups=groups,
+            flip=flip,
+            np_dtype=np_dtype,
+        ):
+
+            scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
+            in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
+            wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
+
+            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+
+            in_pt, wt_pt = map(
+                lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
+                (in_np, wt_np),
+            )
+
+            out_mx = mx.conv_general(
+                in_mx,
+                wt_mx,
+                stride=stride,
+                padding=padding,
+                kernel_dilation=kernel_dilation,
+                input_dilation=input_dilation,
+                groups=groups,
+                flip=flip,
+            )
+
+            def conv_general_pt(
+                inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
+            ):
+
+                C = inp.size()[1]
+                ndim = inp.ndim - 2
+                map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
+
+                stride, padding, kernel_dilation, input_dilation = map(
+                    map_ints, (stride, padding, kernel_dilation, input_dilation)
+                )
+
+                torch_convt_list = (
+                    F.conv_transpose1d,
+                    F.conv_transpose2d,
+                    F.conv_transpose3d,
+                )
+                torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)
+
+                conv_f = torch_conv_list[ndim - 1]
+                convt_f = torch_convt_list[ndim - 1]
+
+                if flip:
+                    wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))
+
+                if not np.all(input_dilation == 1):
+                    ones = torch.ones([C] + [1] * (ndim + 1)).to(inp.dtype)
+                    inp = convt_f(inp, ones, stride=input_dilation, groups=C)
+
+                return conv_f(
+                    inp,
+                    wt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=kernel_dilation,
+                    groups=groups,
+                )
+
+            out_pt = conv_general_pt(
+                in_pt,
+                wt_pt,
+                stride=stride,
+                padding=padding,
+                kernel_dilation=kernel_dilation,
+                input_dilation=input_dilation,
+                groups=groups,
+                flip=flip,
+            )
+
+            out_pt = np.moveaxis(out_pt.numpy(), 1, -1)
+
+            self.assertEqual(out_mx.shape, out_pt.shape)
+            self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_general(self):
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 5, 16)
+        stride = (1, 1)
+        padding = (2, 2)
+        kernel_dilation = (2, 3)
+        input_dilation = (1, 1)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 10, 16)
+        stride = (2, 3)
+        padding = (0, 0)
+        kernel_dilation = (3, 2)
+        input_dilation = (2, 4)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 10, 16)
+        stride = (2, 2)
+        padding = (3, 2)
+        kernel_dilation = (3, 2)
+        input_dilation = (2, 4)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 10, 16)
+        stride = (2, 3)
+        padding = (3, 2)
+        kernel_dilation = (3, 2)
+        input_dilation = (2, 5)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 5, 16)
+        stride = (2, 3)
+        padding = (0, 0)
+        kernel_dilation = (3, 1)
+        input_dilation = (2, 5)
+        flip = True
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+
 if __name__ == "__main__":
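(The conv_general_pt reference above emulates input dilation, which PyTorch's forward convolutions lack, via a per-channel transposed convolution with a size-1 all-ones kernel: stride d then inserts d - 1 zeros between neighboring input samples. A standalone NumPy sketch of that zero-insertion, with an illustrative helper name:)

    import numpy as np

    def dilate_input(x, d):
        # Insert d - 1 zeros between samples along the last axis.
        out = np.zeros(x.shape[:-1] + ((x.shape[-1] - 1) * d + 1,), x.dtype)
        out[..., ::d] = x
        return out

    print(dilate_input(np.arange(1.0, 5.0), 2))  # [1. 0. 2. 0. 3. 0. 4.]

This is what convt_f(inp, ones, stride=input_dilation, groups=C) computes per channel: a transposed conv with a size-1 kernel and stride d scatters each input sample d positions apart.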