Add support for grouped 1D convolutions to the nn API (#1444)

* Fix the weight shape for grouped convolutions from the nn API.

* Add tests.

* Pre-commit formatting.

* Add input validation.

* Use integer division instead of casting.

* docs

* nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Lucas Newman
2024-09-28 06:41:07 -07:00
committed by GitHub
parent b1e2b53c2d
commit 4a64d4bff1
2 changed files with 23 additions and 2 deletions

View File

@@ -650,6 +650,14 @@ class TestLayers(mlx_tests.MLXTestCase):
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
self.assertTrue("bias" not in c.parameters())
groups = C_in
c = nn.Conv1d(
in_channels=C_in, out_channels=C_out, kernel_size=ks, groups=groups
)
y = c(x)
self.assertEqual(c.weight.shape, (C_out, ks, C_in // groups))
self.assertEqual(y.shape, (N, L - ks + 1, C_out))
def test_conv2d(self):
x = mx.ones((4, 8, 8, 3))
c = nn.Conv2d(3, 1, 8)