Mirror of https://github.com/ml-explore/mlx.git (synced 2025-11-04 10:38:10 +08:00)
			
		
		
		
	Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations. Also fixed 1D grouped convs with different kernel strides and added more tests.
* Fix channels condition.
This commit is contained in:
		@@ -123,9 +123,13 @@ class TestConv(mlx_tests.MLXTestCase):
 | 
			
		||||
 | 
			
		||||
        # Groups tests
 | 
			
		||||
        N, C, O = (4, 32, 64)
 | 
			
		||||
        iH, kH, stride, padding = (31, 5, 1, 2)
 | 
			
		||||
        for group in (1, 2, 4, 8, 16, 32):
 | 
			
		||||
            run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
 | 
			
		||||
        for iH, kH, stride, padding in (
 | 
			
		||||
            (1, 1, 1, 0),
 | 
			
		||||
            (3, 3, 1, 0),
 | 
			
		||||
            (31, 5, 5, 2),
 | 
			
		||||
        ):
 | 
			
		||||
            for group in (1, 2, 4, 8, 16, 32):
 | 
			
		||||
                run_conv1D(N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype)
 | 
			
		||||
 | 
			
		||||
        # Strided inputs tests
 | 
			
		||||
        for tpose_in, tpose_wt in (
 | 
			
		||||
@@ -291,7 +295,9 @@ class TestConv(mlx_tests.MLXTestCase):
 | 
			
		||||
                kH, kW = kdim
 | 
			
		||||
                scale = 1.0 / math.sqrt(kH * kW * C)
 | 
			
		||||
                in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
 | 
			
		||||
-                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)
 | 
			
		||||
+                wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
 | 
			
		||||
+                    np_dtype
 | 
			
		||||
+                )
 | 
			
		||||
 | 
			
		||||
                in_mx, wt_mx = map(mx.array, (in_np, wt_np))
 | 
			
		||||
                in_pt, wt_pt = map(
 | 
			
		||||
@@ -334,6 +340,18 @@ class TestConv(mlx_tests.MLXTestCase):
 | 
			
		||||
                ):
 | 
			
		||||
                    run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
 | 
			
		||||
 | 
			
		||||
            # Groups tests
 | 
			
		||||
            N, C, O = (4, 32, 64)
 | 
			
		||||
            for idim, kdim, stride, padding in (
 | 
			
		||||
                ((1, 1), (1, 1), (1, 1), (0, 0)),
 | 
			
		||||
                ((3, 3), (3, 1), (1, 1), (0, 0)),
 | 
			
		||||
                ((31, 31), (5, 5), (5, 5), (2, 2)),
 | 
			
		||||
            ):
 | 
			
		||||
                for group in (1, 2, 4, 8, 16, 32):
 | 
			
		||||
                    run_conv2D(
 | 
			
		||||
                        N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(not has_torch, "requires Torch")
 | 
			
		||||
    def test_torch_conv_2D_grad(self):
 | 
			
		||||
        def run_conv2D_grad(
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user