mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 08:24:39 +08:00
Add groups to Conv1d (#948)
* Add conv1d grouped convs on CPU * Add GPU support * Parallelize inside metal kernel * clenaup * Update mlx/ops.cpp Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * New unfold kernel + remove unused code * Remove copy and refactor * Update vjp and reuse steel gemm * Fixed groups on cpu * Fix metal validation --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -77,7 +77,9 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
np_dtype = getattr(np, dtype)
|
||||
np.random.seed(0)
|
||||
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
in_pt, wt_pt = map(
|
||||
@@ -119,6 +121,12 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
):
|
||||
run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
|
||||
|
||||
# Groups tests
|
||||
N, C, O = (4, 32, 64)
|
||||
iH, kH, stride, padding = (31, 5, 1, 2)
|
||||
for group in (1, 2, 4, 8, 16, 32):
|
||||
run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
|
||||
|
||||
# Strided inputs tests
|
||||
for tpose_in, tpose_wt in (
|
||||
((0, 2, 1), (0, 1, 2)),
|
||||
|
Reference in New Issue
Block a user