mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 16:13:52 +08:00
Depthwise Conv2D optimization (#2036)
- Add a new specialized kernel for depthwise 2D convolutions with small kernels (kernel size <= 7) and small strides (strides <= 2) - Add related tests
This commit is contained in:
@@ -707,9 +707,11 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
flip=flip,
|
||||
np_dtype=np_dtype,
|
||||
):
|
||||
np.random.seed(0)
|
||||
scale = 1.0 / math.sqrt(np.prod(wt_shape[1:]))
|
||||
in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype)
|
||||
wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype)
|
||||
scale = min(0.3, scale)
|
||||
in_np = np.random.normal(0, scale, in_shape).astype(np_dtype)
|
||||
wt_np = np.random.normal(0, scale, wt_shape).astype(np_dtype)
|
||||
|
||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||
|
||||
@@ -1050,6 +1052,42 @@ class TestConv(mlx_tests.MLXTestCase):
|
||||
y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1)
|
||||
self.assertTrue(mx.allclose(y1, y2))
|
||||
|
||||
@unittest.skipIf(not has_torch, "requires Torch")
def test_torch_conv_depthwise(self):
    # Depthwise 2D convolution cases: every row below has
    # groups == C == O, so each filter sees C // groups == 1 input channel.
    # Kernel sizes range from 1 to 7 and each stride component is 1 or 2.

    # fmt: off
    cases = (
        # N,  H,  W,  C, kH, kW,  O, strides, padding, groups
        ( 2, 16, 16, 32,  1,  1, 32, (2, 2),  (1, 1),  32),
        ( 1, 16, 16, 32,  3,  3, 32, (2, 2),  (1, 1),  32),
        ( 1, 32, 32, 32,  7,  7, 32, (1, 1),  (3, 3),  32),
        ( 3, 32, 32, 32,  5,  5, 32, (1, 2),  (0, 0),  32),
        ( 1, 32, 32, 32,  7,  7, 32, (2, 1),  (1, 3),  32),
    )
    # fmt: on

    # Half precision is only exercised when the default device is the GPU.
    on_gpu = mx.default_device() == mx.gpu
    dtypes = [np.float32, np.float16] if on_gpu else [np.float32]

    for N, H, W, C, kH, kW, O, strides, padding, groups in cases:
        # Both shapes depend only on the case row, not on dtype or flip.
        channels_per_group = C // groups
        input_shape = (N, H, W, C)
        weight_shape = (O, kH, kW, channels_per_group)

        for dtype in dtypes:
            # float32 gets the tight tolerance; any other dtype a looser one.
            if dtype == np.float32:
                atol = 2e-5
            else:
                atol = 5e-4

            for flip in (False, True):
                # NOTE(review): the helper is defined elsewhere in the class;
                # presumably it checks the result against a Torch reference.
                self.__conv_general_test(
                    input_shape,
                    weight_shape,
                    strides,
                    padding,
                    kernel_dilation=1,
                    input_dilation=1,
                    groups=groups,
                    flip=flip,
                    np_dtype=dtype,
                    atol=atol,
                )
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
Reference in New Issue
Block a user