add groups in conv2d (#1569)

This commit is contained in:
Awni Hannun
2024-11-07 13:57:53 -08:00
committed by GitHub
parent 9a3842a2d9
commit 59247c2b62
3 changed files with 37 additions and 5 deletions

View File

@@ -706,6 +706,12 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(y.shape, (4, 4, 4, 8))
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
# 3x3 conv groups > 1
x = mx.ones((4, 7, 7, 4))
c = nn.Conv2d(4, 8, 3, padding=1, stride=1, groups=2)
y = c(x)
self.assertEqual(y.shape, (4, 7, 7, 8))
def test_sequential(self):
x = mx.ones((10, 2))
m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))