mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
Add support for grouped 1D convolutions to the nn API (#1444)
* Fix the weight shape for grouped convolutions from the nn API. * Add tests. * Pre-commit formatting. * Add input validation. * Use integer division instead of casting. * docs * nit --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -650,6 +650,14 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
|
||||
self.assertTrue("bias" not in c.parameters())
|
||||
|
||||
groups = C_in
|
||||
c = nn.Conv1d(
|
||||
in_channels=C_in, out_channels=C_out, kernel_size=ks, groups=groups
|
||||
)
|
||||
y = c(x)
|
||||
self.assertEqual(c.weight.shape, (C_out, ks, C_in // groups))
|
||||
self.assertEqual(y.shape, (N, L - ks + 1, C_out))
|
||||
|
||||
def test_conv2d(self):
|
||||
x = mx.ones((4, 8, 8, 3))
|
||||
c = nn.Conv2d(3, 1, 8)
|
||||
|
Reference in New Issue
Block a user