Add groups to Conv1d (#948)

* Add conv1d grouped convs on CPU

* Add GPU support

* Parallelize inside metal kernel

* clenaup

* Update mlx/ops.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* New unfold kernel + remove unused code

* Remove copy and refactor

* Update vjp and reuse steel gemm

* Fixed groups on cpu

* Fix metal validation

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Rifur13
2024-04-27 09:24:57 -04:00
committed by GitHub
parent 86f495985b
commit c4a471c99d
11 changed files with 633 additions and 55 deletions

View File

@@ -77,7 +77,9 @@ class TestConv(mlx_tests.MLXTestCase):
np_dtype = getattr(np, dtype)
np.random.seed(0)
in_np = np.random.normal(0, 1.0 / C, (N, iH, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, C)).astype(np_dtype)
wt_np = np.random.normal(0, 1.0 / C, (O, kH, int(C / groups))).astype(
np_dtype
)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
in_pt, wt_pt = map(
@@ -119,6 +121,12 @@ class TestConv(mlx_tests.MLXTestCase):
):
run_conv1D(N, C, O, iH, kH, stride, padding, dtype=dtype)
# Groups tests
N, C, O = (4, 32, 64)
iH, kH, stride, padding = (31, 5, 1, 2)
for group in (1, 2, 4, 8, 16, 32):
run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
# Strided inputs tests
for tpose_in, tpose_wt in (
((0, 2, 1), (0, 1, 2)),