	Convolution update (#651)
* Init steel conv and update Conv primitive
* Update slow CPU implementation to support flipping and input dilation
* Winograd conv routing

Co-authored-by: Awni Hannun <awni@apple.com>
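The tests in the diff below exercise the updated mx.conv_general primitive. As a minimal usage sketch (not part of the commit; shapes borrowed from the tests), the call takes channels-last (N, H, W, C) inputs and (O, kH, kW, C) weights, and now supports input_dilation and flip alongside stride, padding, kernel_dilation, and groups:

import mlx.core as mx

x = mx.random.normal((2, 32, 32, 16))  # input, (N, H, W, C_in)
w = mx.random.normal((32, 5, 5, 16))   # weights, (C_out, kH, kW, C_in)

y = mx.conv_general(
    x,
    w,
    stride=(2, 2),
    padding=(1, 1),
    kernel_dilation=(1, 1),
    input_dilation=(2, 2),  # insert zeros between input samples (fractional stride)
    groups=1,
    flip=True,  # reverse the kernel's spatial axes: true convolution, not cross-correlation
)
print(y.shape)  # (2, 31, 31, 32): H dilates to 63, pads to 65, then (65 - 5) // 2 + 1 = 31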
python/tests/test_conv.py

@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import math
 import unittest
@@ -388,13 +388,8 @@ class TestConv(mlx_tests.MLXTestCase):
 
                 _, outs_mx = mx.vjp(
                     f,
-                    [
-                        in_mx,
-                        wt_mx,
-                    ],
-                    [
-                        ct_mx,
-                    ],
+                    [in_mx, wt_mx],
+                    [ct_mx],
                 )
                 pt_grad_in = F.grad.conv1d_input(
                     in_pt.shape,
@@ -428,18 +423,218 @@ class TestConv(mlx_tests.MLXTestCase):
                 self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
 
         for dtype in ("float32",):
-            for N, C, O in (
-                (1, 1, 1),
-                (1, 6, 1),
-                (1, 1, 6),
-                (4, 32, 64),
-            ):
-                for idim, kdim, stride, padding in (
-                    ((1, 1), (1, 1), (1, 1), (0, 0)),
-                    ((3, 3), (3, 1), (1, 1), (0, 0)),
-                    ((31, 31), (5, 5), (5, 5), (2, 2)),
+            for N, C, O in ((1, 1, 1), (1, 6, 1), (1, 1, 6), (4, 32, 64), (4, 16, 32)):
+                for idim, kdim, stride, padding, dilation in (
+                    ((1, 1), (1, 1), (1, 1), (0, 0), (1, 1)),
+                    ((3, 3), (3, 1), (1, 1), (0, 0), (1, 1)),
+                    ((31, 31), (5, 5), (5, 5), (2, 2), (1, 1)),
+                    ((32, 32), (3, 3), (2, 2), (1, 1), (1, 1)),
+                    ((31, 31), (5, 5), (5, 5), (2, 2), (3, 2)),
+                    ((32, 32), (3, 3), (2, 2), (1, 1), (3, 2)),
                 ):
-                    run_conv2D_grad(N, C, O, idim, kdim, stride, padding, dtype=dtype)
+                    run_conv2D_grad(
+                        N, C, O, idim, kdim, stride, padding, dilation, dtype=dtype
+                    )
+
+    def __conv_general_test(
+        self,
+        in_shape,
+        wt_shape,
+        stride=1,
+        padding=0,
+        kernel_dilation=1,
+        input_dilation=1,
+        groups=1,
+        flip=False,
+        np_dtype=np.float32,
+        atol=1e-5,
+    ):
+
+        with self.subTest(
+            in_shape=in_shape,
+            wt_shape=wt_shape,
+            stride=stride,
+            padding=padding,
+            kernel_dilation=kernel_dilation,
+            input_dilation=input_dilation,
+            groups=groups,
+            flip=flip,
+            np_dtype=np_dtype,
+        ):
+
+            scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
+            in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
+            wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
+
+            in_mx, wt_mx = map(mx.array, (in_np, wt_np))
+
+            in_pt, wt_pt = map(
+                lambda x: torch.from_numpy(np.moveaxis(x, -1, 1)).to("cpu"),
+                (in_np, wt_np),
+            )
+
+            out_mx = mx.conv_general(
+                in_mx,
+                wt_mx,
+                stride=stride,
+                padding=padding,
+                kernel_dilation=kernel_dilation,
+                input_dilation=input_dilation,
+                groups=groups,
+                flip=flip,
+            )
+
+            def conv_general_pt(
+                inp, wt, stride, padding, kernel_dilation, input_dilation, groups, flip
+            ):
+
+                C = inp.size()[1]
+                ndim = inp.ndim - 2
+                map_ints = lambda x: [x] * ndim if isinstance(x, int) else x
+
+                stride, padding, kernel_dilation, input_dilation = map(
+                    map_ints, (stride, padding, kernel_dilation, input_dilation)
+                )
+
+                torch_convt_list = (
+                    F.conv_transpose1d,
+                    F.conv_transpose2d,
+                    F.conv_transpose3d,
+                )
+                torch_conv_list = (F.conv1d, F.conv2d, F.conv3d)
+
+                conv_f = torch_conv_list[ndim - 1]
+                convt_f = torch_convt_list[ndim - 1]
+
+                if flip:
+                    wt = torch.flip(wt, tuple(np.arange(2, wt.ndim)))
+
+                if not np.all(input_dilation == 1):
+                    ones = torch.ones(
+                        [C]
+                        + [
+                            1,
+                        ]
+                        * (ndim + 1)
+                    ).to(inp.dtype)
+                    inp = convt_f(inp, ones, stride=input_dilation, groups=C)
+
+                return conv_f(
+                    inp,
+                    wt,
+                    stride=stride,
+                    padding=padding,
+                    dilation=kernel_dilation,
+                    groups=groups,
+                )
+
+            out_pt = conv_general_pt(
+                in_pt,
+                wt_pt,
+                stride=stride,
+                padding=padding,
+                kernel_dilation=kernel_dilation,
+                input_dilation=input_dilation,
+                groups=groups,
+                flip=flip,
+            )
+
+            out_pt = np.moveaxis(out_pt.numpy(), 1, -1)
+
+            self.assertEqual(out_mx.shape, out_pt.shape)
+            self.assertTrue(np.allclose(out_mx, out_pt, atol=atol))
+
+    @unittest.skipIf(not has_torch, "requires Torch")
+    def test_torch_conv_general(self):
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 5, 16)
+        stride = (1, 1)
+        padding = (2, 2)
+        kernel_dilation = (2, 3)
+        input_dilation = (1, 1)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 10, 16)
+        stride = (2, 3)
+        padding = (0, 0)
+        kernel_dilation = (3, 2)
+        input_dilation = (2, 4)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 10, 16)
+        stride = (2, 2)
+        padding = (3, 2)
+        kernel_dilation = (3, 2)
+        input_dilation = (2, 4)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 10, 16)
+        stride = (2, 3)
+        padding = (3, 2)
+        kernel_dilation = (3, 2)
+        input_dilation = (2, 5)
+        flip = False
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
+
+        in_shape = (2, 32, 32, 16)
+        wt_shape = (32, 5, 5, 16)
+        stride = (2, 3)
+        padding = (0, 0)
+        kernel_dilation = (3, 1)
+        input_dilation = (2, 5)
+        flip = True
+
+        self.__conv_general_test(
+            in_shape,
+            wt_shape,
+            stride,
+            padding,
+            kernel_dilation,
+            input_dilation,
+            flip=flip,
+        )
 
 
 if __name__ == "__main__":
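The conv_general_pt reference helper in the new test deserves a note: torch.nn.functional.conv2d has no input_dilation argument, so the test emulates it with a transposed convolution whose weight is a 1x1 ones kernel with groups equal to the channel count; with stride=input_dilation this scatters each input sample onto a zero-filled grid. A standalone sketch of that identity (PyTorch assumed installed; not part of the commit):

import torch
import torch.nn.functional as F

x = torch.arange(1.0, 5.0).reshape(1, 1, 2, 2)  # one channel, 2x2 input
ones = torch.ones(1, 1, 1, 1)                   # 1x1 identity kernel
y = F.conv_transpose2d(x, ones, stride=(2, 2), groups=1)

print(y)
# tensor([[[[1., 0., 2.],
#           [0., 0., 0.],
#           [3., 0., 4.]]]])
# Zeros are inserted between samples: exactly an input-dilated tensor.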
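Likewise, the helper maps flip=True onto stock PyTorch by reversing the spatial axes of the weights before convolving, since flipping the kernel turns the cross-correlation that deep-learning conv ops compute into true convolution. A quick 1-D check of that identity in NumPy (illustrative only):

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
k = np.array([1.0, 0.0, -1.0])

conv = np.convolve(x, k, mode="valid")           # true convolution (flips k internally)
xcorr = np.correlate(x, k[::-1], mode="valid")   # cross-correlation with a flipped kernel
print(np.allclose(conv, xcorr))                  # True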