mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations. Also fixed 1D grouped convs with different kernel strides and added more tests. * fix channels condition
This commit is contained in:
@@ -123,9 +123,13 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
|
||||
# Groups tests
|
||||
N, C, O = (4, 32, 64)
|
||||
iH, kH, stride, padding = (31, 5, 1, 2)
|
||||
for group in (1, 2, 4, 8, 16, 32):
|
||||
run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
|
||||
for iH, kH, stride, padding in (
|
||||
(1, 1, 1, 0),
|
||||
(3, 3, 1, 0),
|
||||
(31, 5, 5, 2),
|
||||
):
|
||||
for group in (1, 2, 4, 8, 16, 32):
|
||||
run_conv1D(N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype)
|
||||
|
||||
# Strided inputs tests
|
||||
for tpose_in, tpose_wt in (
|
||||
@@ -291,7 +295,9 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
kH, kW = kdim
|
||||
scale = 1.0 / math.sqrt(kH * kW * C)
|
||||
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt, wt_pt = map(
|
||||
@@ -334,6 +340,18 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
):
|
||||
run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
||||
|
||||
# Groups tests
|
||||
N, C, O = (4, 32, 64)
|
||||
for idim, kdim, stride, padding in (
|
||||
((1, 1), (1, 1), (1, 1), (0, 0)),
|
||||
((3, 3), (3, 1), (1, 1), (0, 0)),
|
||||
((31, 31), (5, 5), (5, 5), (2, 2)),
|
||||
):
|
||||
for group in (1, 2, 4, 8, 16, 32):
|
||||
run_conv2D(
|
||||
N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
|
||||
)
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
|
||||
def test_torch_conv_2D_grad(self):
|
||||
def run_conv2D_grad(
|
||||
|
Reference in New Issue
Block a user