Add support for grouped 1D convolutions to the nn API (#1444)

* Fix the weight shape for grouped convolutions from the nn API.

* Add tests.

* Pre-commit formatting.

* Add input validation.

* Use integer division instead of casting.

* docs

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Lucas Newman
Date: 2024-09-28 06:41:07 -07:00 (committed by GitHub)
Parent: b1e2b53c2d
Commit: 4a64d4bff1
2 changed files with 23 additions and 2 deletions
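
For context, here is a minimal usage sketch of what this change enables. The sizes are hypothetical (not taken from the diff): with groups=4, each filter sees only in_channels // 4 input channels, which is exactly the weight-shape fix described above.

import mlx.core as mx
import mlx.nn as nn

# Hypothetical sizes: 16 input channels split into 4 groups, so each
# filter convolves only 16 // 4 = 4 channels.
conv = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, groups=4)

x = mx.ones((1, 100, 16))  # MLX convolutions are channels-last: (N, L, C_in)
y = conv(x)

print(conv.weight.shape)  # (32, 3, 4): (out_channels, kernel_size, in_channels // groups)
print(y.shape)            # (1, 98, 32): with no padding, L shrinks by kernel_size - 1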

@@ -25,6 +25,8 @@ class Conv1d(Module):
         padding (int, optional): How many positions to 0-pad the input with.
             Default: ``0``.
         dilation (int, optional): The dilation of the convolution.
+        groups (int, optional): The number of groups for the convolution.
+            Default: ``1``.
         bias (bool, optional): If ``True`` add a learnable bias to the output.
             Default: ``True``
     """
@@ -37,15 +39,22 @@ class Conv1d(Module):
         stride: int = 1,
         padding: int = 0,
         dilation: int = 1,
+        groups: int = 1,
         bias: bool = True,
     ):
         super().__init__()
+        if in_channels % groups != 0:
+            raise ValueError(
+                f"The number of input channels ({in_channels}) must be "
+                f"divisible by the number of groups ({groups})"
+            )
+
         scale = math.sqrt(1 / (in_channels * kernel_size))
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(out_channels, kernel_size, in_channels),
+            shape=(out_channels, kernel_size, in_channels // groups),
         )
         if bias:
             self.bias = mx.zeros((out_channels,))
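
The new validation fails fast at construction time rather than letting the convolution error later. A quick sketch of the behavior, with hypothetical values:

import mlx.nn as nn

# 10 input channels cannot be split evenly into 4 groups, so __init__ raises.
try:
    nn.Conv1d(in_channels=10, out_channels=8, kernel_size=3, groups=4)
except ValueError as e:
    print(e)  # The number of input channels (10) must be divisible by the number of groups (4)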
@@ -53,17 +62,21 @@ class Conv1d(Module):
         self.padding = padding
         self.dilation = dilation
         self.stride = stride
+        self.groups = groups

     def _extra_repr(self):
         return (
             f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
             f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
             f"padding={self.padding}, dilation={self.dilation}, "
+            f"groups={self.groups}, "
             f"bias={'bias' in self}"
         )

     def __call__(self, x):
-        y = mx.conv1d(x, self.weight, self.stride, self.padding, self.dilation)
+        y = mx.conv1d(
+            x, self.weight, self.stride, self.padding, self.dilation, self.groups
+        )
         if "bias" in self:
             y = y + self.bias
         return y
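
__call__ now forwards groups positionally as the sixth argument to mx.conv1d. A standalone sketch of that call, with hypothetical shapes:

import mlx.core as mx

x = mx.ones((1, 10, 4))          # (batch, length, in_channels)
w = mx.ones((6, 3, 2))           # (out_channels, kernel_size, in_channels // groups), groups=2
y = mx.conv1d(x, w, 1, 0, 1, 2)  # stride=1, padding=0, dilation=1, groups=2
print(y.shape)                   # (1, 8, 6)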

@@ -650,6 +650,14 @@ class TestLayers(mlx_tests.MLXTestCase):
         c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
         self.assertTrue("bias" not in c.parameters())

+        groups = C_in
+        c = nn.Conv1d(
+            in_channels=C_in, out_channels=C_out, kernel_size=ks, groups=groups
+        )
+        y = c(x)
+        self.assertEqual(c.weight.shape, (C_out, ks, C_in // groups))
+        self.assertEqual(y.shape, (N, L - ks + 1, C_out))
+
     def test_conv2d(self):
         x = mx.ones((4, 8, 8, 3))
         c = nn.Conv2d(3, 1, 8)
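
The added test exercises the depthwise case (groups == in_channels), where each input channel gets its own filters and the per-group weight slice is a single channel. A standalone version with hypothetical sizes; out_channels is kept equal to in_channels so it stays divisible by groups:

import mlx.core as mx
import mlx.nn as nn

N, L, C_in, ks = 2, 32, 8, 5  # hypothetical sizes
conv = nn.Conv1d(in_channels=C_in, out_channels=C_in, kernel_size=ks, groups=C_in)

x = mx.random.normal((N, L, C_in))
y = conv(x)

assert conv.weight.shape == (C_in, ks, 1)  # in_channels // groups == 1
assert y.shape == (N, L - ks + 1, C_in)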