mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU * Add GPU support * Parallelize inside metal kernel * clenaup * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * New unfold kernel + remove unused code * Remove copy and refactor * Update vjp and reuse steel gemm * Fixed groups on cpu * Fix metal validation --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
		| @@ -77,7 +77,9 @@ class TestConv(mlx_tests.MLXTestCase): | ||||
|                 np_dtype = getattr(np, dtype) | ||||
|                 np.random.seed(0) | ||||
|                 in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype) | ||||
|                 wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype) | ||||
|                 wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype( | ||||
|                     np_dtype | ||||
|                 ) | ||||
|  | ||||
|                 in_mx, wt_mx = map(mx.array, (in_np, wt_np)) | ||||
|                 in_pt, wt_pt = map( | ||||
| @@ -119,6 +121,12 @@ class TestConv(mlx_tests.MLXTestCase): | ||||
|                 ): | ||||
|                     run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype) | ||||
|  | ||||
|         # Groups tests | ||||
|         N, C, O = (4, 32, 64) | ||||
|         iH, kH, stride, padding = (31, 5, 1, 2) | ||||
|         for group in (1, 2, 4, 8, 16, 32): | ||||
|             run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype) | ||||
|  | ||||
|         # Strided inputs tests | ||||
|         for tpose_in, tpose_wt in ( | ||||
|             ((0, 2, 1), (0, 1, 2)), | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Rifur13
					Rifur13