mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations. Also fixed 1D grouped convs with different kernel strides and added more tests. * fix channels condition
This commit is contained in:
		| @@ -123,9 +123,13 @@ class TestConv(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         # Groups tests | ||||
|         N, C, O = (4, 32, 64) | ||||
|         iH, kH, stride, padding = (31, 5, 1, 2) | ||||
|         for group in (1, 2, 4, 8, 16, 32): | ||||
|             run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype) | ||||
|         for iH, kH, stride, padding in ( | ||||
|             (1, 1, 1, 0), | ||||
|             (3, 3, 1, 0), | ||||
|             (31, 5, 5, 2), | ||||
|         ): | ||||
|             for group in (1, 2, 4, 8, 16, 32): | ||||
|                 run_conv1D(N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype) | ||||
|  | ||||
|         # Strided inputs tests | ||||
|         for tpose_in, tpose_wt in ( | ||||
| @@ -291,7 +295,9 @@ class TestConv(mlx_tests.MLXTestCase): | ||||
|                 kH, kW = kdim | ||||
|                 scale = 1.0 / math.sqrt(kH * kW * C) | ||||
|                 in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype) | ||||
|                 wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype) | ||||
|                 wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype( | ||||
|                     np_dtype | ||||
|                 ) | ||||
|  | ||||
|                 in_mx, wt_mx = map(mx.array, (in_np, wt_np)) | ||||
|                 in_pt, wt_pt = map( | ||||
| @@ -334,6 +340,18 @@ class TestConv(mlx_tests.MLXTestCase): | ||||
|                 ): | ||||
|                     run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype) | ||||
|  | ||||
|             # Groups tests | ||||
|             N, C, O = (4, 32, 64) | ||||
|             for idim, kdim, stride, padding in ( | ||||
|                 ((1, 1), (1, 1), (1, 1), (0, 0)), | ||||
|                 ((3, 3), (3, 1), (1, 1), (0, 0)), | ||||
|                 ((31, 31), (5, 5), (5, 5), (2, 2)), | ||||
|             ): | ||||
|                 for group in (1, 2, 4, 8, 16, 32): | ||||
|                     run_conv2D( | ||||
|                         N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype | ||||
|                     ) | ||||
|  | ||||
|     @unittest.skipIf(not has_torch, "requires Torch") | ||||
|     def test_torch_conv_2D_grad(self): | ||||
|         def run_conv2D_grad( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Rifur13
					Rifur13