Add groups to 2-D convolutions (#1129)

* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1-D grouped convolutions with different kernel strides and added more tests.

* fix channels condition
This commit is contained in:
Rifur13
2024-05-22 23:01:44 -04:00
committed by GitHub
parent eb8321d863
commit 9401507336
9 changed files with 322 additions and 132 deletions

View File

@@ -123,9 +123,13 @@ class TestConv(mlx_tests.MLXTestCase):
# Groups tests
N, C, O = (4, 32, 64)
iH, kH, stride, padding = (31, 5, 1, 2)
for group in (1, 2, 4, 8, 16, 32):
run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
for iH, kH, stride, padding in (
(1, 1, 1, 0),
(3, 3, 1, 0),
(31, 5, 5, 2),
):
for group in (1, 2, 4, 8, 16, 32):
run_conv1D(N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype)
# Strided inputs tests
for tpose_in, tpose_wt in (
@@ -291,7 +295,9 @@ class TestConv(mlx_tests.MLXTestCase):
kH, kW = kdim
scale = 1.0 / math.sqrt(kH * kW * C)
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
np_dtype
)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
@@ -334,6 +340,18 @@ class TestConv(mlx_tests.MLXTestCase):
):
run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
# Groups tests
N, C, O = (4, 32, 64)
for idim, kdim, stride, padding in (
((1, 1), (1, 1), (1, 1), (0, 0)),
((3, 3), (3, 1), (1, 1), (0, 0)),
((31, 31), (5, 5), (5, 5), (2, 2)),
):
for group in (1, 2, 4, 8, 16, 32):
run_conv2D(
N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
)
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_2D_grad(self):
def run_conv2D_grad(