mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	add groups in conv2d (#1569)
This commit is contained in:
		| @@ -101,6 +101,8 @@ class Conv2d(Module): | ||||
|         padding (int or tuple, optional): How many positions to 0-pad | ||||
|             the input with. Default: ``0``. | ||||
|         dilation (int or tuple, optional): The dilation of the convolution. | ||||
|         groups (int, optional): The number of groups for the convolution. | ||||
|             Default: ``1``. | ||||
|         bias (bool, optional): If ``True`` add a learnable bias to the | ||||
|             output. Default: ``True`` | ||||
|     """ | ||||
| @@ -113,10 +115,17 @@ class Conv2d(Module): | ||||
|         stride: Union[int, tuple] = 1, | ||||
|         padding: Union[int, tuple] = 0, | ||||
|         dilation: Union[int, tuple] = 1, | ||||
|         groups: int = 1, | ||||
|         bias: bool = True, | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         if in_channels % groups != 0: | ||||
|             raise ValueError( | ||||
|                 f"The number of input channels ({in_channels}) must be " | ||||
|                 f"divisible by the number of groups ({groups})" | ||||
|             ) | ||||
|  | ||||
|         kernel_size, stride, padding = map( | ||||
|             lambda x: (x, x) if isinstance(x, int) else x, | ||||
|             (kernel_size, stride, padding), | ||||
| @@ -125,7 +134,7 @@ class Conv2d(Module): | ||||
|         self.weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(out_channels, *kernel_size, in_channels), | ||||
|             shape=(out_channels, *kernel_size, in_channels // groups), | ||||
|         ) | ||||
|         if bias: | ||||
|             self.bias = mx.zeros((out_channels,)) | ||||
| @@ -133,17 +142,21 @@ class Conv2d(Module): | ||||
|         self.padding = padding | ||||
|         self.stride = stride | ||||
|         self.dilation = dilation | ||||
|         self.groups = groups | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|         return ( | ||||
|             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " | ||||
|             f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, " | ||||
|             f"padding={self.padding}, dilation={self.dilation}, " | ||||
|             f"groups={self.groups}, " | ||||
|             f"bias={'bias' in self}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         y = mx.conv2d(x, self.weight, self.stride, self.padding, self.dilation) | ||||
|         y = mx.conv2d( | ||||
|             x, self.weight, self.stride, self.padding, self.dilation, self.groups | ||||
|         ) | ||||
|         if "bias" in self: | ||||
|             y = y + self.bias | ||||
|         return y | ||||
|   | ||||
| @@ -706,6 +706,12 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(y.shape, (4, 4, 4, 8)) | ||||
|         self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4) | ||||
|  | ||||
|         # 3x3 conv groups > 1 | ||||
|         x = mx.ones((4, 7, 7, 4)) | ||||
|         c = nn.Conv2d(4, 8, 3, padding=1, stride=1, groups=2) | ||||
|         y = c(x) | ||||
|         self.assertEqual(y.shape, (4, 7, 7, 8)) | ||||
|  | ||||
|     def test_sequential(self): | ||||
|         x = mx.ones((10, 2)) | ||||
|         m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun