mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Add support for grouped 1D convolutions to the nn API (#1444)
* Fix the weight shape for grouped convolutions from the nn API.
* Add tests.
* Pre-commit formatting.
* Add input validation.
* Use integer division instead of casting.
* docs
* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -25,6 +25,8 @@ class Conv1d(Module): | ||||
|         padding (int, optional): How many positions to 0-pad the input with. | ||||
|             Default: ``0``. | ||||
|         dilation (int, optional): The dilation of the convolution. | ||||
|         groups (int, optional): The number of groups for the convolution. | ||||
|             Default: ``1``. | ||||
|         bias (bool, optional): If ``True`` add a learnable bias to the output. | ||||
|             Default: ``True`` | ||||
|     """ | ||||
| @@ -37,15 +39,22 @@ class Conv1d(Module): | ||||
|         stride: int = 1, | ||||
|         padding: int = 0, | ||||
|         dilation: int = 1, | ||||
|         groups: int = 1, | ||||
|         bias: bool = True, | ||||
|     ): | ||||
|         super().__init__() | ||||
|  | ||||
|         if in_channels % groups != 0: | ||||
|             raise ValueError( | ||||
|                 f"The number of input channels ({in_channels}) must be " | ||||
|                 f"divisible by the number of groups ({groups})" | ||||
|             ) | ||||
|  | ||||
|         scale = math.sqrt(1 / (in_channels * kernel_size)) | ||||
|         self.weight = mx.random.uniform( | ||||
|             low=-scale, | ||||
|             high=scale, | ||||
|             shape=(out_channels, kernel_size, in_channels), | ||||
|             shape=(out_channels, kernel_size, in_channels // groups), | ||||
|         ) | ||||
|         if bias: | ||||
|             self.bias = mx.zeros((out_channels,)) | ||||
| @@ -53,17 +62,21 @@ class Conv1d(Module): | ||||
|         self.padding = padding | ||||
|         self.dilation = dilation | ||||
|         self.stride = stride | ||||
|         self.groups = groups | ||||
|  | ||||
|     def _extra_repr(self): | ||||
|         return ( | ||||
|             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " | ||||
|             f"kernel_size={self.weight.shape[1]}, stride={self.stride}, " | ||||
|             f"padding={self.padding}, dilation={self.dilation}, " | ||||
|             f"groups={self.groups}, " | ||||
|             f"bias={'bias' in self}" | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x): | ||||
|         y = mx.conv1d(x, self.weight, self.stride, self.padding, self.dilation) | ||||
|         y = mx.conv1d( | ||||
|             x, self.weight, self.stride, self.padding, self.dilation, self.groups | ||||
|         ) | ||||
|         if "bias" in self: | ||||
|             y = y + self.bias | ||||
|         return y | ||||
|   | ||||
| @@ -650,6 +650,14 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|         c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False) | ||||
|         self.assertTrue("bias" not in c.parameters()) | ||||
|  | ||||
|         groups = C_in | ||||
|         c = nn.Conv1d( | ||||
|             in_channels=C_in, out_channels=C_out, kernel_size=ks, groups=groups | ||||
|         ) | ||||
|         y = c(x) | ||||
|         self.assertEqual(c.weight.shape, (C_out, ks, C_in // groups)) | ||||
|         self.assertEqual(y.shape, (N, L - ks + 1, C_out)) | ||||
|  | ||||
|     def test_conv2d(self): | ||||
|         x = mx.ones((4, 8, 8, 3)) | ||||
|         c = nn.Conv2d(3, 1, 8) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Lucas Newman
					Lucas Newman