mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations. Also fixed 1D grouped convs with different kernel strides and added more tests. * fix channels condition
This commit is contained in:
@@ -28,11 +28,11 @@ def bench(f, a, b):
|
||||
return (e - s) * 1e-9
|
||||
|
||||
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
def mx_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding)
|
||||
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
mx.eval(ys)
|
||||
return ys
|
||||
@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
return mx_conv_2D
|
||||
|
||||
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||
@torch.no_grad()
|
||||
def pt_conv_2D(a, b):
|
||||
ys = []
|
||||
for i in range(N_iter_func):
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding)
|
||||
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||
ys.append(y)
|
||||
torch.mps.synchronize()
|
||||
return ys
|
||||
@@ -53,11 +53,13 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
||||
return pt_conv_2D
|
||||
|
||||
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||
|
||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
|
||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||
np_dtype
|
||||
)
|
||||
|
||||
a_mx = mx.array(a_np)
|
||||
b_mx = mx.array(b_np)
|
||||
@@ -67,15 +69,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
torch.mps.synchronize()
|
||||
|
||||
f_mx = make_mx_conv_2D(strides, padding)
|
||||
f_pt = make_pt_conv_2D(strides, padding)
|
||||
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||
|
||||
time_torch = bench(f_pt, a_pt, b_pt)
|
||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
|
||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||
out_pt = torch.conv2d(
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
|
||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||
)
|
||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||
out_pt = out_pt.numpy(force=True)
|
||||
@@ -84,7 +86,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
||||
|
||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||
print(
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||
)
|
||||
|
||||
return time_mlx, time_torch
|
||||
@@ -95,35 +97,40 @@ if __name__ == "__main__":
|
||||
|
||||
dtypes = ("float32",)
|
||||
shapes = (
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
|
||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
|
||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
|
||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||
)
|
||||
|
||||
for dtype in dtypes:
|
||||
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
|
||||
for N, H, W, C, kH, kW, O, strides, padding in shapes:
|
||||
print(
|
||||
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||
)
|
||||
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||
np_dtype = getattr(np, dtype)
|
||||
time_mlx, time_torch = bench_shape(
|
||||
N, H, W, C, kH, kW, O, strides, padding, np_dtype
|
||||
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||
)
|
||||
diff = time_torch / time_mlx - 1.0
|
||||
|
||||
print(
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
|
||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||
)
|
||||
if time_mlx >= 2.0 * time_torch:
|
||||
print("ATTENTION ^^^^^^^")
|
||||
|
Reference in New Issue
Block a user