Depthwise Conv2D optimization (#2036)

- Add a new specialized kernel for depthwise 2D convolutions with small kernels (kernel size <= 7) and small strides (strides <= 2)
- Add related tests
This commit is contained in:
Jagrit Digani
2025-04-03 09:42:04 -07:00
committed by GitHub
parent c41f7565ed
commit 8777fd104f
3 changed files with 232 additions and 4 deletions

View File

@@ -707,9 +707,11 @@ class TestConv(mlx_tests.MLXTestCase):
flip=flip,
np_dtype=np_dtype,
):
np.random.seed(0)
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
scale = min(0.3, scale)
in_np = np.random.normal(0, scale, in_shape).astype(np_dtype)
wt_np = np.random.normal(0, scale, wt_shape).astype(np_dtype)
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
@@ -1050,6 +1052,42 @@ class TestConv(mlx_tests.MLXTestCase):
y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
self.assertTrue(mx.allclose(y1, y2))
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_depthwise(self):
    """Run the general conv check over a table of depthwise 2D convolutions.

    Every row sets ``groups`` equal to the input channel count ``C``, so the
    last dimension of the weight shape (``C // groups``) is 1. The rows span
    kernel sizes from 1 to 7 and strides of 1 or 2, with and without padding.
    Each row is exercised for every selected dtype and for both ``flip``
    settings.
    """
    # fmt: off
    cases = (
        # N,  H,  W,  C, kH, kW,  O, strides, padding, groups
        (  2, 16, 16, 32,  1,  1, 32, (2, 2), (1, 1), 32),
        (  1, 16, 16, 32,  3,  3, 32, (2, 2), (1, 1), 32),
        (  1, 32, 32, 32,  7,  7, 32, (1, 1), (3, 3), 32),
        (  3, 32, 32, 32,  5,  5, 32, (1, 2), (0, 0), 32),
        (  1, 32, 32, 32,  7,  7, 32, (2, 1), (1, 3), 32),
    )
    # fmt: on

    # float16 is only added when the default device is the GPU.
    np_dtypes = [np.float32]
    if mx.default_device() == mx.gpu:
        np_dtypes.append(np.float16)

    for N, H, W, C, kH, kW, O, strides, padding, groups in cases:
        # Shapes depend only on the table row, so build them once per row.
        in_shape = (N, H, W, C)
        wt_shape = (O, kH, kW, C // groups)
        for np_dtype in np_dtypes:
            # Looser absolute tolerance for anything other than float32.
            tolerance = 2e-5 if np_dtype == np.float32 else 5e-4
            for flip in (False, True):
                self.__conv_general_test(
                    in_shape,
                    wt_shape,
                    strides,
                    padding,
                    kernel_dilation=1,
                    input_dilation=1,
                    groups=groups,
                    flip=flip,
                    np_dtype=np_dtype,
                    atol=tolerance,
                )
if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()