mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations. Also fixed 1D grouped convs with different kernel strides and added more tests.
* Fix channels condition.
This commit is contained in:
parent
eb8321d863
commit
9401507336
@ -16,6 +16,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
||||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||||
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
|
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
|
||||||
|
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
|
@ -28,11 +28,11 @@ def bench(f, a, b):
|
|||||||
return (e - s) * 1e-9
|
return (e - s) * 1e-9
|
||||||
|
|
||||||
|
|
||||||
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
def mx_conv_2D(a, b):
|
def mx_conv_2D(a, b):
|
||||||
ys = []
|
ys = []
|
||||||
for i in range(N_iter_func):
|
for i in range(N_iter_func):
|
||||||
y = mx.conv2d(a, b, stride=strides, padding=padding)
|
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
ys.append(y)
|
ys.append(y)
|
||||||
mx.eval(ys)
|
mx.eval(ys)
|
||||||
return ys
|
return ys
|
||||||
@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
|
|||||||
return mx_conv_2D
|
return mx_conv_2D
|
||||||
|
|
||||||
|
|
||||||
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def pt_conv_2D(a, b):
|
def pt_conv_2D(a, b):
|
||||||
ys = []
|
ys = []
|
||||||
for i in range(N_iter_func):
|
for i in range(N_iter_func):
|
||||||
y = torch.conv2d(a, b, stride=strides, padding=padding)
|
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
||||||
ys.append(y)
|
ys.append(y)
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
return ys
|
return ys
|
||||||
@ -53,11 +53,13 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
|
|||||||
return pt_conv_2D
|
return pt_conv_2D
|
||||||
|
|
||||||
|
|
||||||
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
|
||||||
|
|
||||||
scale = 1.0 / math.sqrt(kH * kH * C)
|
scale = 1.0 / math.sqrt(kH * kH * C)
|
||||||
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
||||||
b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
|
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
a_mx = mx.array(a_np)
|
a_mx = mx.array(a_np)
|
||||||
b_mx = mx.array(b_np)
|
b_mx = mx.array(b_np)
|
||||||
@ -67,15 +69,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
|||||||
|
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
|
|
||||||
f_mx = make_mx_conv_2D(strides, padding)
|
f_mx = make_mx_conv_2D(strides, padding, groups)
|
||||||
f_pt = make_pt_conv_2D(strides, padding)
|
f_pt = make_pt_conv_2D(strides, padding, groups)
|
||||||
|
|
||||||
time_torch = bench(f_pt, a_pt, b_pt)
|
time_torch = bench(f_pt, a_pt, b_pt)
|
||||||
time_mlx = bench(f_mx, a_mx, b_mx)
|
time_mlx = bench(f_mx, a_mx, b_mx)
|
||||||
|
|
||||||
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
|
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
||||||
out_pt = torch.conv2d(
|
out_pt = torch.conv2d(
|
||||||
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
|
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
||||||
)
|
)
|
||||||
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
||||||
out_pt = out_pt.numpy(force=True)
|
out_pt = out_pt.numpy(force=True)
|
||||||
@ -84,7 +86,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
|
|||||||
|
|
||||||
if not np.allclose(out_pt, out_mx, atol=atol):
|
if not np.allclose(out_pt, out_mx, atol=atol):
|
||||||
print(
|
print(
|
||||||
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return time_mlx, time_torch
|
return time_mlx, time_torch
|
||||||
@ -95,35 +97,40 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
dtypes = ("float32",)
|
dtypes = ("float32",)
|
||||||
shapes = (
|
shapes = (
|
||||||
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
|
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
|
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
|
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
|
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
|
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
||||||
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
|
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
|
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
|
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
||||||
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
|
||||||
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
|
||||||
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
|
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
|
||||||
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
|
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
|
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
|
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
||||||
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
|
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
||||||
|
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
|
print(
|
||||||
for N, H, W, C, kH, kW, O, strides, padding in shapes:
|
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
||||||
|
)
|
||||||
|
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
||||||
np_dtype = getattr(np, dtype)
|
np_dtype = getattr(np, dtype)
|
||||||
time_mlx, time_torch = bench_shape(
|
time_mlx, time_torch = bench_shape(
|
||||||
N, H, W, C, kH, kW, O, strides, padding, np_dtype
|
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
||||||
)
|
)
|
||||||
diff = time_torch / time_mlx - 1.0
|
diff = time_torch / time_mlx - 1.0
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
|
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
||||||
)
|
)
|
||||||
if time_mlx >= 2.0 * time_torch:
|
if time_mlx >= 2.0 * time_torch:
|
||||||
print("ATTENTION ^^^^^^^")
|
print("ATTENTION ^^^^^^^")
|
||||||
|
@ -111,13 +111,17 @@ void slow_conv_2D(
|
|||||||
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
|
||||||
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
|
||||||
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
|
const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
|
||||||
|
const int C = in.shape(3); // In channels
|
||||||
const int oH = out.shape(1); // Output spatial dim
|
const int oH = out.shape(1); // Output spatial dim
|
||||||
const int oW = out.shape(2); // Output spatial dim
|
const int oW = out.shape(2); // Output spatial dim
|
||||||
const int O = wt.shape(0); // Out channels
|
const int O = wt.shape(0); // Out channels
|
||||||
const int C = wt.shape(3); // In channels
|
|
||||||
const int wH = wt.shape(1); // Weight spatial dim
|
const int wH = wt.shape(1); // Weight spatial dim
|
||||||
const int wW = wt.shape(2); // Weight spatial dim
|
const int wW = wt.shape(2); // Weight spatial dim
|
||||||
|
|
||||||
|
const int groups = C / wt.shape(3);
|
||||||
|
const int C_per_group = wt.shape(3);
|
||||||
|
const int O_per_group = O / groups;
|
||||||
|
|
||||||
const size_t in_stride_N = in.strides()[0];
|
const size_t in_stride_N = in.strides()[0];
|
||||||
const size_t in_stride_H = in.strides()[1];
|
const size_t in_stride_H = in.strides()[1];
|
||||||
const size_t in_stride_W = in.strides()[2];
|
const size_t in_stride_W = in.strides()[2];
|
||||||
@ -141,33 +145,35 @@ void slow_conv_2D(
|
|||||||
int ih_base = oh * wt_strides[0] - padding[0];
|
int ih_base = oh * wt_strides[0] - padding[0];
|
||||||
int iw_base = ow * wt_strides[1] - padding[1];
|
int iw_base = ow * wt_strides[1] - padding[1];
|
||||||
|
|
||||||
for (int o = 0; o < O; ++o) {
|
for (int g = 0; g < groups; ++g) {
|
||||||
float r = 0.;
|
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||||
|
float r = 0.;
|
||||||
|
|
||||||
for (int wh = 0; wh < wH; ++wh) {
|
for (int wh = 0; wh < wH; ++wh) {
|
||||||
for (int ww = 0; ww < wW; ++ww) {
|
for (int ww = 0; ww < wW; ++ww) {
|
||||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||||
|
|
||||||
const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
const T* wt_ptr_pt =
|
||||||
const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
|
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||||
|
const T* in_ptr_pt =
|
||||||
|
in_ptr + ih * in_stride_H + iw * in_stride_W;
|
||||||
|
|
||||||
for (int c = 0; c < C; ++c) {
|
for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
|
||||||
r += static_cast<float>(in_ptr_pt[0]) *
|
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||||
static_cast<float>(wt_ptr_pt[0]);
|
static_cast<float>(
|
||||||
in_ptr_pt += in_stride_C;
|
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||||
wt_ptr_pt += wt_stride_C;
|
} // c
|
||||||
} // c
|
} // ww
|
||||||
|
} // wh
|
||||||
|
|
||||||
} // ww
|
out_ptr[0] = static_cast<T>(r);
|
||||||
} // wh
|
out_ptr += out_stride_O;
|
||||||
|
wt_ptr += wt_stride_O;
|
||||||
out_ptr[0] = static_cast<T>(r);
|
} // o
|
||||||
out_ptr += out_stride_O;
|
} // g
|
||||||
wt_ptr += wt_stride_O;
|
|
||||||
} // o
|
|
||||||
};
|
};
|
||||||
|
|
||||||
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
|
||||||
@ -219,41 +225,43 @@ void slow_conv_2D(
|
|||||||
int wh_base = base_h[oh % f_out_jump_h];
|
int wh_base = base_h[oh % f_out_jump_h];
|
||||||
int ww_base = base_w[ow % f_out_jump_w];
|
int ww_base = base_w[ow % f_out_jump_w];
|
||||||
|
|
||||||
for (int o = 0; o < O; ++o) {
|
for (int g = 0; g < groups; ++g) {
|
||||||
float r = 0.;
|
for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
|
||||||
|
float r = 0.;
|
||||||
|
|
||||||
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
|
||||||
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
|
||||||
int wh_flip = flip ? wH - wh - 1 : wh;
|
int wh_flip = flip ? wH - wh - 1 : wh;
|
||||||
int ww_flip = flip ? wW - ww - 1 : ww;
|
int ww_flip = flip ? wW - ww - 1 : ww;
|
||||||
int ih = ih_base + wh_flip * wt_dilation[0];
|
int ih = ih_base + wh_flip * wt_dilation[0];
|
||||||
int iw = iw_base + ww_flip * wt_dilation[1];
|
int iw = iw_base + ww_flip * wt_dilation[1];
|
||||||
|
|
||||||
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
|
||||||
const T* wt_ptr_pt =
|
const T* wt_ptr_pt =
|
||||||
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
|
||||||
|
|
||||||
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
|
||||||
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;
|
||||||
|
|
||||||
const T* in_ptr_pt =
|
const T* in_ptr_pt =
|
||||||
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;
|
||||||
|
|
||||||
for (int c = 0; c < C; ++c) {
|
for (int c = g * C_per_group; c < (g + 1) * C_per_group;
|
||||||
r += static_cast<float>(in_ptr_pt[0]) *
|
++c) {
|
||||||
static_cast<float>(wt_ptr_pt[0]);
|
r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
|
||||||
in_ptr_pt += in_stride_C;
|
static_cast<float>(
|
||||||
wt_ptr_pt += wt_stride_C;
|
wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
|
||||||
} // c
|
} // c
|
||||||
|
|
||||||
} // ih, iw check
|
} // ih, iw check
|
||||||
} // ww
|
} // ww
|
||||||
} // wh
|
} // wh
|
||||||
|
|
||||||
out_ptr[0] = static_cast<T>(r);
|
out_ptr[0] = static_cast<T>(r);
|
||||||
out_ptr += out_stride_O;
|
out_ptr += out_stride_O;
|
||||||
wt_ptr += wt_stride_O;
|
wt_ptr += wt_stride_O;
|
||||||
} // o
|
} // o
|
||||||
|
} // g
|
||||||
};
|
};
|
||||||
|
|
||||||
int oH_border_0 = 0;
|
int oH_border_0 = 0;
|
||||||
|
@ -257,15 +257,19 @@ void implicit_gemm_conv_2D_gpu(
|
|||||||
const array& wt,
|
const array& wt,
|
||||||
array out,
|
array out,
|
||||||
const MLXConvParams<2>& conv_params) {
|
const MLXConvParams<2>& conv_params) {
|
||||||
|
const int groups = conv_params.groups;
|
||||||
|
const int C_per_group = conv_params.C / conv_params.groups;
|
||||||
|
const int O_per_group = conv_params.O / conv_params.groups;
|
||||||
|
|
||||||
// Deduce implicit gemm size
|
// Deduce implicit gemm size
|
||||||
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
||||||
int implicit_N = conv_params.O;
|
const int implicit_N = O_per_group;
|
||||||
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
|
const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
|
||||||
|
|
||||||
// Determine block and warp tiles
|
// Determine block and warp tiles
|
||||||
int wm = 2, wn = 2;
|
int wm = 2, wn = 2;
|
||||||
|
|
||||||
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
|
int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
|
||||||
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
|
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
|
||||||
int bk = 16;
|
int bk = 16;
|
||||||
|
|
||||||
@ -281,15 +285,15 @@ void implicit_gemm_conv_2D_gpu(
|
|||||||
|
|
||||||
// Fix small channel specialization
|
// Fix small channel specialization
|
||||||
int n_channel_specialization = 0;
|
int n_channel_specialization = 0;
|
||||||
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
int channel_k_iters = ((C_per_group + bk - 1) / bk);
|
||||||
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
|
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
|
||||||
|
|
||||||
if (conv_params.C <= 2) {
|
if (C_per_group <= 2) {
|
||||||
gemm_k_iters = (implicit_K + bk - 1) / bk;
|
gemm_k_iters = (implicit_K + bk - 1) / bk;
|
||||||
n_channel_specialization = conv_params.C;
|
n_channel_specialization = C_per_group;
|
||||||
} else if (conv_params.C <= 4) {
|
} else if (C_per_group <= 4) {
|
||||||
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
|
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
|
||||||
n_channel_specialization = conv_params.C;
|
n_channel_specialization = C_per_group;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool small_filter = (!n_channel_specialization) &&
|
bool small_filter = (!n_channel_specialization) &&
|
||||||
@ -340,7 +344,7 @@ void implicit_gemm_conv_2D_gpu(
|
|||||||
size_t grid_dim_x = tn * tile;
|
size_t grid_dim_x = tn * tile;
|
||||||
|
|
||||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
|
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);
|
||||||
|
|
||||||
// Encode arrays
|
// Encode arrays
|
||||||
compute_encoder.set_input_array(in, 0);
|
compute_encoder.set_input_array(in, 0);
|
||||||
@ -703,6 +707,7 @@ void conv_2D_gpu(
|
|||||||
const std::vector<int>& wt_strides,
|
const std::vector<int>& wt_strides,
|
||||||
const std::vector<int>& wt_dilation,
|
const std::vector<int>& wt_dilation,
|
||||||
const std::vector<int>& in_dilation,
|
const std::vector<int>& in_dilation,
|
||||||
|
const int groups,
|
||||||
bool flip,
|
bool flip,
|
||||||
std::vector<array>& copies) {
|
std::vector<array>& copies) {
|
||||||
// Make conv params
|
// Make conv params
|
||||||
@ -718,12 +723,12 @@ void conv_2D_gpu(
|
|||||||
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
|
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
|
||||||
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
|
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
|
||||||
/* const size_t in_strides[NDIM + 2] = */
|
/* const size_t in_strides[NDIM + 2] = */
|
||||||
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
|
{in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
|
||||||
/* const size_t wt_strides[NDIM + 2] = */
|
/* const size_t wt_strides[NDIM + 2] = */
|
||||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
{wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
|
||||||
/* const size_t out_strides[NDIM + 2] = */
|
/* const size_t out_strides[NDIM + 2] = */
|
||||||
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
{out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
|
||||||
/* const int groups = */ 1,
|
/* const int groups = */ groups,
|
||||||
/* const bool flip = */ flip,
|
/* const bool flip = */ flip,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -735,6 +740,18 @@ void conv_2D_gpu(
|
|||||||
bool channels_large = (conv_params.C + conv_params.O) >= 512;
|
bool channels_large = (conv_params.C + conv_params.O) >= 512;
|
||||||
bool channels_med = (conv_params.C + conv_params.O) >= 256;
|
bool channels_med = (conv_params.C + conv_params.O) >= 256;
|
||||||
|
|
||||||
|
if (groups > 1) {
|
||||||
|
const int C_per_group = conv_params.C / groups;
|
||||||
|
const int O_per_group = conv_params.O / groups;
|
||||||
|
|
||||||
|
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||||
|
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||||
|
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
} else {
|
||||||
|
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Direct to winograd conv
|
// Direct to winograd conv
|
||||||
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
|
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||||
@ -860,6 +877,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
kernel_strides_,
|
kernel_strides_,
|
||||||
kernel_dilation_,
|
kernel_dilation_,
|
||||||
input_dilation_,
|
input_dilation_,
|
||||||
|
groups_,
|
||||||
flip_,
|
flip_,
|
||||||
copies);
|
copies);
|
||||||
}
|
}
|
||||||
|
@ -109,6 +109,7 @@ template <typename T, int N>
|
|||||||
bool valid = n < params->N;
|
bool valid = n < params->N;
|
||||||
|
|
||||||
// Unroll dimensions
|
// Unroll dimensions
|
||||||
|
int kernel_stride = 1;
|
||||||
for (int i = N - 1; i >= 0; --i) {
|
for (int i = N - 1; i >= 0; --i) {
|
||||||
int os_ = (oS % params->oS[i]);
|
int os_ = (oS % params->oS[i]);
|
||||||
int ws_ = (wS % params->wS[i]);
|
int ws_ = (wS % params->wS[i]);
|
||||||
@ -125,7 +126,8 @@ template <typename T, int N>
|
|||||||
oS /= params->oS[i];
|
oS /= params->oS[i];
|
||||||
wS /= params->wS[i];
|
wS /= params->wS[i];
|
||||||
|
|
||||||
out += ws_ * params->str[i];
|
out += ws_ * kernel_stride;
|
||||||
|
kernel_stride *= params->wS[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (valid) {
|
if (valid) {
|
||||||
|
@ -133,9 +133,15 @@ implicit_gemm_conv_2d(
|
|||||||
const int c_col = tid_x * BN;
|
const int c_col = tid_x * BN;
|
||||||
const int K = gemm_params->K;
|
const int K = gemm_params->K;
|
||||||
const int N = gemm_params->N;
|
const int N = gemm_params->N;
|
||||||
|
const int C_per_group = params->C / params->groups;
|
||||||
|
|
||||||
|
// Groups
|
||||||
|
A += tid.z * C_per_group;
|
||||||
|
B += tid.z * N * K;
|
||||||
|
C += tid.z * N;
|
||||||
|
|
||||||
B += c_col * K;
|
B += c_col * K;
|
||||||
C += c_row * N + c_col;
|
C += c_row * (N * params->groups) + c_col;
|
||||||
|
|
||||||
const int2 offsets_a(0, c_row);
|
const int2 offsets_a(0, c_row);
|
||||||
const int2 offsets_b(0, c_col);
|
const int2 offsets_b(0, c_col);
|
||||||
@ -171,7 +177,8 @@ implicit_gemm_conv_2d(
|
|||||||
// Store results to device memory
|
// Store results to device memory
|
||||||
short tgp_bm = min(BM, gemm_params->M - c_row);
|
short tgp_bm = min(BM, gemm_params->M - c_row);
|
||||||
short tgp_bn = min(BN, gemm_params->N - c_col);
|
short tgp_bn = min(BN, gemm_params->N - c_col);
|
||||||
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
|
const int ldc = N * params->groups;
|
||||||
|
mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
|
||||||
}
|
}
|
||||||
|
|
||||||
#define instantiate_implicit_conv_2d( \
|
#define instantiate_implicit_conv_2d( \
|
||||||
|
@ -3180,9 +3180,9 @@ array conv_general(
|
|||||||
bool flip /* = false */,
|
bool flip /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
// Run checks
|
// Run checks
|
||||||
if (groups != 1 && in.ndim() != 3) {
|
if (groups != 1 && in.ndim() != 3 && in.ndim() != 4) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[conv] Can only handle groups != 1 in 1D convolutions.");
|
"[conv] Can only handle groups != 1 in 1D or 2D convolutions.");
|
||||||
}
|
}
|
||||||
|
|
||||||
int spatial_dims = in.ndim() - 2;
|
int spatial_dims = in.ndim() - 2;
|
||||||
|
@ -123,9 +123,13 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
# Groups tests
|
# Groups tests
|
||||||
N, C, O = (4, 32, 64)
|
N, C, O = (4, 32, 64)
|
||||||
iH, kH, stride, padding = (31, 5, 1, 2)
|
for iH, kH, stride, padding in (
|
||||||
for group in (1, 2, 4, 8, 16, 32):
|
(1, 1, 1, 0),
|
||||||
run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
|
(3, 3, 1, 0),
|
||||||
|
(31, 5, 5, 2),
|
||||||
|
):
|
||||||
|
for group in (1, 2, 4, 8, 16, 32):
|
||||||
|
run_conv1D(N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype)
|
||||||
|
|
||||||
# Strided inputs tests
|
# Strided inputs tests
|
||||||
for tpose_in, tpose_wt in (
|
for tpose_in, tpose_wt in (
|
||||||
@ -291,7 +295,9 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
kH, kW = kdim
|
kH, kW = kdim
|
||||||
scale = 1.0 / math.sqrt(kH * kW * C)
|
scale = 1.0 / math.sqrt(kH * kW * C)
|
||||||
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
|
||||||
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)
|
wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
|
||||||
|
np_dtype
|
||||||
|
)
|
||||||
|
|
||||||
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
in_mx, wt_mx = map(mx.array, (in_np, wt_np))
|
||||||
in_pt, wt_pt = map(
|
in_pt, wt_pt = map(
|
||||||
@ -334,6 +340,18 @@ class TestConv(mlx_tests.MLXTestCase):
|
|||||||
):
|
):
|
||||||
run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)
|
||||||
|
|
||||||
|
# Groups tests
|
||||||
|
N, C, O = (4, 32, 64)
|
||||||
|
for idim, kdim, stride, padding in (
|
||||||
|
((1, 1), (1, 1), (1, 1), (0, 0)),
|
||||||
|
((3, 3), (3, 1), (1, 1), (0, 0)),
|
||||||
|
((31, 31), (5, 5), (5, 5), (2, 2)),
|
||||||
|
):
|
||||||
|
for group in (1, 2, 4, 8, 16, 32):
|
||||||
|
run_conv2D(
|
||||||
|
N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
@unittest.skipIf(not has_torch, "requires Torch")
|
@unittest.skipIf(not has_torch, "requires Torch")
|
||||||
def test_torch_conv_2D_grad(self):
|
def test_torch_conv_2D_grad(self):
|
||||||
def run_conv2D_grad(
|
def run_conv2D_grad(
|
||||||
|
@ -3268,22 +3268,22 @@ TEST_CASE("test conv1d") {
|
|||||||
float16);
|
float16);
|
||||||
|
|
||||||
auto expected = array(
|
auto expected = array(
|
||||||
{1.5685,
|
{1.56836,
|
||||||
0.5672,
|
0.567383,
|
||||||
1.8121,
|
1.8125,
|
||||||
1.2948,
|
1.29492,
|
||||||
2.3448,
|
2.34375,
|
||||||
1.6104,
|
1.61035,
|
||||||
2.7743,
|
2.77539,
|
||||||
1.6126,
|
1.61328,
|
||||||
1.4056,
|
1.40527,
|
||||||
0.9331,
|
0.933105,
|
||||||
1.8739,
|
1.87402,
|
||||||
1.0909},
|
1.09082},
|
||||||
{1, 3, 4});
|
{1, 3, 4});
|
||||||
|
|
||||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
CHECK(allclose(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -3309,22 +3309,151 @@ TEST_CASE("test conv1d") {
|
|||||||
{4, 3, 1});
|
{4, 3, 1});
|
||||||
|
|
||||||
auto expected = array(
|
auto expected = array(
|
||||||
{1.0703,
|
{1.07007,
|
||||||
0.7533,
|
0.753201,
|
||||||
0.7007,
|
0.700818,
|
||||||
0.4681,
|
0.468176,
|
||||||
1.1859,
|
1.18568,
|
||||||
0.9117,
|
0.91152,
|
||||||
0.9565,
|
0.956607,
|
||||||
0.6111,
|
0.611213,
|
||||||
0.6416,
|
0.641404,
|
||||||
0.5665,
|
0.566401,
|
||||||
0.9074,
|
0.907472,
|
||||||
0.0605},
|
0.0605397},
|
||||||
{1, 3, 4});
|
{1, 3, 4});
|
||||||
|
|
||||||
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
|
||||||
CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
|
CHECK(allclose(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test conv2d") {
|
||||||
|
auto in = array(
|
||||||
|
{0.57429284,
|
||||||
|
-0.21628855,
|
||||||
|
-0.18673691,
|
||||||
|
-0.3793517,
|
||||||
|
|
||||||
|
0.3059678,
|
||||||
|
-0.8137168,
|
||||||
|
0.6168841,
|
||||||
|
-0.26912728},
|
||||||
|
{1, 2, 2, 2});
|
||||||
|
|
||||||
|
std::pair<int, int> kernel{2, 2};
|
||||||
|
std::pair<int, int> stride{1, 1};
|
||||||
|
std::pair<int, int> padding{0, 0};
|
||||||
|
|
||||||
|
{
|
||||||
|
int groups = 1;
|
||||||
|
|
||||||
|
auto wt = array(
|
||||||
|
{0.3190391, -0.24937038, 1.4621079, -2.0601406, -0.3224172,
|
||||||
|
-0.38405436, 1.1337694, -1.0998913, -0.1724282, -0.8778584,
|
||||||
|
0.04221375, 0.58281523, -1.1006192, 1.1447237, 0.9015907,
|
||||||
|
0.50249434, 0.90085596, -0.68372786, -0.12289023, -0.93576944,
|
||||||
|
-0.26788807, 0.53035545, -0.69166076, -0.39675352, -0.6871727,
|
||||||
|
-0.84520566, -0.6712461, -0.0126646, -1.1173104, 0.2344157,
|
||||||
|
1.6598022, 0.74204415},
|
||||||
|
{4, 2, 2, 2});
|
||||||
|
|
||||||
|
auto expected =
|
||||||
|
array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
|
||||||
|
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
|
||||||
|
CHECK(allclose(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
int groups = 2;
|
||||||
|
auto wt = array(
|
||||||
|
{0.3190391,
|
||||||
|
-0.24937038,
|
||||||
|
|
||||||
|
1.46210794,
|
||||||
|
-2.06014071,
|
||||||
|
|
||||||
|
-0.3224172,
|
||||||
|
-0.38405435,
|
||||||
|
|
||||||
|
1.13376944,
|
||||||
|
-1.09989127,
|
||||||
|
|
||||||
|
-0.17242821,
|
||||||
|
-0.87785842,
|
||||||
|
|
||||||
|
0.04221375,
|
||||||
|
0.58281521,
|
||||||
|
|
||||||
|
-1.10061918,
|
||||||
|
1.14472371,
|
||||||
|
|
||||||
|
0.90159072,
|
||||||
|
0.50249434},
|
||||||
|
{4, 2, 2, 1});
|
||||||
|
|
||||||
|
auto expected = array(
|
||||||
|
{-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
|
||||||
|
|
||||||
|
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
|
||||||
|
CHECK(allclose(out, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
in = array(
|
||||||
|
{0.57429284,
|
||||||
|
-0.21628855,
|
||||||
|
-0.18673691,
|
||||||
|
-0.3793517,
|
||||||
|
|
||||||
|
0.3059678,
|
||||||
|
-0.8137168,
|
||||||
|
0.6168841,
|
||||||
|
-0.26912728,
|
||||||
|
|
||||||
|
0.57429284,
|
||||||
|
-0.21628855,
|
||||||
|
-0.18673691,
|
||||||
|
-0.3793517,
|
||||||
|
|
||||||
|
0.3059678,
|
||||||
|
-0.8137168,
|
||||||
|
0.6168841,
|
||||||
|
-0.26912728},
|
||||||
|
{2, 2, 2, 2});
|
||||||
|
|
||||||
|
int groups = 2;
|
||||||
|
auto wt = array(
|
||||||
|
{0.3190391,
|
||||||
|
-0.24937038,
|
||||||
|
|
||||||
|
1.46210794,
|
||||||
|
-2.06014071,
|
||||||
|
|
||||||
|
-0.3224172,
|
||||||
|
-0.38405435,
|
||||||
|
|
||||||
|
1.13376944,
|
||||||
|
-1.09989127,
|
||||||
|
|
||||||
|
-0.17242821,
|
||||||
|
-0.87785842,
|
||||||
|
|
||||||
|
0.04221375,
|
||||||
|
0.58281521,
|
||||||
|
|
||||||
|
-1.10061918,
|
||||||
|
1.14472371,
|
||||||
|
|
||||||
|
0.90159072,
|
||||||
|
0.50249434},
|
||||||
|
{4, 2, 2, 1});
|
||||||
|
|
||||||
|
auto expected = array(
|
||||||
|
{-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
|
||||||
|
|
||||||
|
auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
|
||||||
|
CHECK(allclose(out, expected).item<bool>());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user