Add groups to 2-D convolutions (#1129)

* Added groups to 2-D convolutions. Only implemented for **some** specializations.

Also fixed 1D grouped convs with different kernel strides and added more tests.

* fix channels condition
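
For reference, a minimal sketch of the grouped 2-D convolution call this commit enables (shapes follow the NHWC input and `(O, kH, kW, C // groups)` weight layout used in the benchmark and tests below; the specific sizes here are made up for illustration):

```python
import mlx.core as mx

# Input is NHWC; with groups, the weight's channel dim is C // groups.
N, H, W, C, O, kH, kW, groups = 2, 16, 16, 32, 64, 3, 3, 4
x = mx.random.normal((N, H, W, C))
w = mx.random.normal((O, kH, kW, C // groups))

# Each block of C // groups input channels feeds O // groups output channels.
y = mx.conv2d(x, w, stride=(1, 1), padding=(1, 1), groups=groups)
print(y.shape)  # (2, 16, 16, 64)
```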
Rifur13 authored 2024-05-22 23:01:44 -04:00, committed by GitHub
parent eb8321d863
commit 9401507336
9 changed files with 322 additions and 132 deletions

View File

@@ -16,6 +16,7 @@ MLX was developed with contributions from the following individuals:
 - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
 - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
 - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
+- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.

 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
   <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@@ -28,11 +28,11 @@ def bench(f, a, b):
     return (e - s) * 1e-9


-def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
+def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
     def mx_conv_2D(a, b):
         ys = []
         for i in range(N_iter_func):
-            y = mx.conv2d(a, b, stride=strides, padding=padding)
+            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
             ys.append(y)
         mx.eval(ys)
         return ys
@@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)):
     return mx_conv_2D


-def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
+def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
     @torch.no_grad()
     def pt_conv_2D(a, b):
         ys = []
         for i in range(N_iter_func):
-            y = torch.conv2d(a, b, stride=strides, padding=padding)
+            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
             ys.append(y)
         torch.mps.synchronize()
         return ys
@@ -53,11 +53,13 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)):
     return pt_conv_2D


-def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
+def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
     scale = 1.0 / math.sqrt(kH * kH * C)
     a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
-    b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype)
+    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
+        np_dtype
+    )

     a_mx = mx.array(a_np)
     b_mx = mx.array(b_np)
@@ -67,15 +69,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
     torch.mps.synchronize()

-    f_mx = make_mx_conv_2D(strides, padding)
-    f_pt = make_pt_conv_2D(strides, padding)
+    f_mx = make_mx_conv_2D(strides, padding, groups)
+    f_pt = make_pt_conv_2D(strides, padding, groups)

     time_torch = bench(f_pt, a_pt, b_pt)
     time_mlx = bench(f_mx, a_mx, b_mx)

-    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding)
+    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
     out_pt = torch.conv2d(
-        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding
+        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
     )
     out_pt = torch.permute(out_pt, (0, 2, 3, 1))
     out_pt = out_pt.numpy(force=True)
@@ -84,7 +86,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype):
     if not np.allclose(out_pt, out_mx, atol=atol):
         print(
-            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
+            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
         )

     return time_mlx, time_torch
@@ -95,35 +97,40 @@ if __name__ == "__main__":
     dtypes = ("float32",)
     shapes = (
-        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)),
-        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)),
-        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)),
-        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)),
-        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)),
-        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)),
-        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)),
-        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)),
-        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)),
-        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)),
-        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)),
-        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)),
-        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)),
-        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)),
-        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)),
-        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)),
+        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
+        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
+        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
+        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
+        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
+        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
+        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
+        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
+        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
+        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
+        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
+        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
+        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
+        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
+        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
+        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
+        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
+        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
+        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
     )

     for dtype in dtypes:
-        print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%")
-        for N, H, W, C, kH, kW, O, strides, padding in shapes:
+        print(
+            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
+        )
+        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
             np_dtype = getattr(np, dtype)

             time_mlx, time_torch = bench_shape(
-                N, H, W, C, kH, kW, O, strides, padding, np_dtype
+                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
             )
             diff = time_torch / time_mlx - 1.0

             print(
-                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%"
+                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
             )
             if time_mlx >= 2.0 * time_torch:
                 print("ATTENTION ^^^^^^^")

View File

@@ -111,13 +111,17 @@ void slow_conv_2D(
   const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
   const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim
   const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim
+  const int C = in.shape(3); // In channels
   const int oH = out.shape(1); // Output spatial dim
   const int oW = out.shape(2); // Output spatial dim
   const int O = wt.shape(0); // Out channels
-  const int C = wt.shape(3); // In channels
   const int wH = wt.shape(1); // Weight spatial dim
   const int wW = wt.shape(2); // Weight spatial dim

+  const int groups = C / wt.shape(3);
+  const int C_per_group = wt.shape(3);
+  const int O_per_group = O / groups;
+
   const size_t in_stride_N = in.strides()[0];
   const size_t in_stride_H = in.strides()[1];
   const size_t in_stride_W = in.strides()[2];
@@ -141,7 +145,8 @@ void slow_conv_2D(
         int ih_base = oh * wt_strides[0] - padding[0];
         int iw_base = ow * wt_strides[1] - padding[1];

-        for (int o = 0; o < O; ++o) {
+        for (int g = 0; g < groups; ++g) {
+          for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
             float r = 0.;

             for (int wh = 0; wh < wH; ++wh) {
@@ -151,16 +156,16 @@ void slow_conv_2D(
                 int ih = ih_base + wh_flip * wt_dilation[0];
                 int iw = iw_base + ww_flip * wt_dilation[1];

-                const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
-                const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W;
+                const T* wt_ptr_pt =
+                    wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
+                const T* in_ptr_pt =
+                    in_ptr + ih * in_stride_H + iw * in_stride_W;

-                for (int c = 0; c < C; ++c) {
-                  r += static_cast<float>(in_ptr_pt[0]) *
-                      static_cast<float>(wt_ptr_pt[0]);
-                  in_ptr_pt += in_stride_C;
-                  wt_ptr_pt += wt_stride_C;
+                for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
+                  r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
+                      static_cast<float>(
+                          wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
                 } // c
               } // ww
             } // wh
@@ -168,6 +173,7 @@ void slow_conv_2D(
             out_ptr += out_stride_O;
             wt_ptr += wt_stride_O;
           } // o
+        } // g
       };

   int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
@@ -219,7 +225,8 @@ void slow_conv_2D(
         int wh_base = base_h[oh % f_out_jump_h];
         int ww_base = base_w[ow % f_out_jump_w];

-        for (int o = 0; o < O; ++o) {
+        for (int g = 0; g < groups; ++g) {
+          for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
             float r = 0.;

             for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
@@ -239,11 +246,11 @@ void slow_conv_2D(
                   const T* in_ptr_pt =
                       in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W;

-                  for (int c = 0; c < C; ++c) {
-                    r += static_cast<float>(in_ptr_pt[0]) *
-                        static_cast<float>(wt_ptr_pt[0]);
-                    in_ptr_pt += in_stride_C;
-                    wt_ptr_pt += wt_stride_C;
+                  for (int c = g * C_per_group; c < (g + 1) * C_per_group;
+                       ++c) {
+                    r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
+                        static_cast<float>(
+                            wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
                   } // c
                 } // ih, iw check
@@ -254,6 +261,7 @@ void slow_conv_2D(
             out_ptr += out_stride_O;
             wt_ptr += wt_stride_O;
           } // o
+        } // g
       };

   int oH_border_0 = 0;
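
The grouped loop above splits the channel dimensions per group: output channel `o` in group `g` only reads input channels `[g * C_per_group, (g + 1) * C_per_group)`, and the weight's channel index is `c % C_per_group`. A plain-NumPy sketch of that indexing (assuming stride 1, no padding, no dilation, no flip; illustration only, not the CPU backend itself):

```python
import numpy as np

def grouped_conv2d_reference(x, w, groups):
    # x: (N, iH, iW, C), w: (O, wH, wW, C // groups), cross-correlation (flip=False).
    N, iH, iW, C = x.shape
    O, wH, wW, C_per_group = w.shape
    O_per_group = O // groups
    oH, oW = iH - wH + 1, iW - wW + 1
    out = np.zeros((N, oH, oW, O), dtype=x.dtype)
    for g in range(groups):
        for o in range(g * O_per_group, (g + 1) * O_per_group):
            for c in range(g * C_per_group, (g + 1) * C_per_group):
                for wh in range(wH):
                    for ww in range(wW):
                        # Input channel c pairs with weight channel c % C_per_group,
                        # mirroring the inner loop of slow_conv_2D.
                        out[..., o] += (
                            x[:, wh : wh + oH, ww : ww + oW, c]
                            * w[o, wh, ww, c % C_per_group]
                        )
    return out
```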

View File

@@ -257,15 +257,19 @@ void implicit_gemm_conv_2D_gpu(
     const array& wt,
     array out,
     const MLXConvParams<2>& conv_params) {
+  const int groups = conv_params.groups;
+  const int C_per_group = conv_params.C / conv_params.groups;
+  const int O_per_group = conv_params.O / conv_params.groups;
+
   // Deduce implicit gemm size
-  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
-  int implicit_N = conv_params.O;
-  int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
+  const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
+  const int implicit_N = O_per_group;
+  const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;

   // Determine block and warp tiles
   int wm = 2, wn = 2;

-  int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
+  int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
   int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
   int bk = 16;
@@ -281,15 +285,15 @@
   // Fix small channel specialization
   int n_channel_specialization = 0;
-  int channel_k_iters = ((conv_params.C + bk - 1) / bk);
+  int channel_k_iters = ((C_per_group + bk - 1) / bk);
   int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;

-  if (conv_params.C <= 2) {
+  if (C_per_group <= 2) {
     gemm_k_iters = (implicit_K + bk - 1) / bk;
-    n_channel_specialization = conv_params.C;
-  } else if (conv_params.C <= 4) {
+    n_channel_specialization = C_per_group;
+  } else if (C_per_group <= 4) {
     gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
-    n_channel_specialization = conv_params.C;
+    n_channel_specialization = C_per_group;
   }

   bool small_filter = (!n_channel_specialization) &&
@@ -340,7 +344,7 @@
   size_t grid_dim_x = tn * tile;

   MTL::Size group_dims = MTL::Size(32, wn, wm);
-  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
+  MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);

   // Encode arrays
   compute_encoder.set_input_array(in, 0);
@@ -703,6 +707,7 @@ void conv_2D_gpu(
     const std::vector<int>& wt_strides,
     const std::vector<int>& wt_dilation,
     const std::vector<int>& in_dilation,
+    const int groups,
     bool flip,
     std::vector<array>& copies) {
   // Make conv params
@@ -718,12 +723,12 @@
       /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
       /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
       /* const size_t in_strides[NDIM + 2] = */
-      {in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
+      {in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
       /* const size_t wt_strides[NDIM + 2] = */
-      {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
+      {wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
       /* const size_t out_strides[NDIM + 2] = */
-      {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
-      /* const int groups = */ 1,
+      {out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
+      /* const int groups = */ groups,
       /* const bool flip = */ flip,
   };
@@ -735,6 +740,18 @@
   bool channels_large = (conv_params.C + conv_params.O) >= 512;
   bool channels_med = (conv_params.C + conv_params.O) >= 256;

+  if (groups > 1) {
+    const int C_per_group = conv_params.C / groups;
+    const int O_per_group = conv_params.O / groups;
+
+    if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
+        (O_per_group <= 16 || O_per_group % 16 == 0)) {
+      return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
+    } else {
+      return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
+    }
+  }
+
   // Direct to winograd conv
   if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
       conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
@@ -860,6 +877,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
         kernel_strides_,
         kernel_dilation_,
         input_dilation_,
+        groups_,
         flip_,
         copies);
   }
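
With groups > 1 the GPU path now chooses between the specialized implicit GEMM kernel and a grouped explicit-GEMM fallback based on per-group channel counts. A sketch of that dispatch rule in Python (illustration only; the real decision is the C++ branch above):

```python
def grouped_conv_uses_implicit_gemm(C, O, groups, is_idil_one):
    """True when per-group channels are small or 16-aligned, matching the
    condition that guards implicit_gemm_conv_2D_gpu for grouped convolutions."""
    C_per_group = C // groups
    O_per_group = O // groups
    return (
        is_idil_one
        and (C_per_group <= 4 or C_per_group % 16 == 0)
        and (O_per_group <= 16 or O_per_group % 16 == 0)
    )
```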

View File

@@ -109,6 +109,7 @@ template <typename T, int N>
   bool valid = n < params->N;

   // Unroll dimensions
+  int kernel_stride = 1;
   for (int i = N - 1; i >= 0; --i) {
     int os_ = (oS % params->oS[i]);
     int ws_ = (wS % params->wS[i]);
@@ -125,7 +126,8 @@ template <typename T, int N>
     oS /= params->oS[i];
     wS /= params->wS[i];

-    out += ws_ * params->str[i];
+    out += ws_ * kernel_stride;
+    kernel_stride *= params->wS[i];
   }

   if (valid) {
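
The `kernel_stride` change fixes how the flattened weight offset is accumulated: it is now built from the kernel's own spatial extents (`params->wS`) rather than the convolution stride, which is what broke 1D grouped convs whose stride differs from the kernel size. A small Python sketch of the corrected index computation (illustration only):

```python
def flat_weight_offset(ws, wS):
    """Row-major offset of spatial index ws within a kernel of dims wS,
    accumulated back-to-front like the i = N-1 .. 0 loop in the kernel."""
    offset, kernel_stride = 0, 1
    for i in reversed(range(len(wS))):
        offset += ws[i] * kernel_stride
        kernel_stride *= wS[i]
    return offset

print(flat_weight_offset([1, 2], [3, 5]))  # 1 * 5 + 2 = 7
```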

View File

@@ -133,9 +133,15 @@ implicit_gemm_conv_2d(
   const int c_col = tid_x * BN;

   const int K = gemm_params->K;
   const int N = gemm_params->N;
+  const int C_per_group = params->C / params->groups;
+
+  // Groups
+  A += tid.z * C_per_group;
+  B += tid.z * N * K;
+  C += tid.z * N;

   B += c_col * K;
-  C += c_row * N + c_col;
+  C += c_row * (N * params->groups) + c_col;

   const int2 offsets_a(0, c_row);
   const int2 offsets_b(0, c_col);
@@ -171,7 +177,8 @@
   // Store results to device memory
   short tgp_bm = min(BM, gemm_params->M - c_row);
   short tgp_bn = min(BN, gemm_params->N - c_col);
-  mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
+  const int ldc = N * params->groups;
+  mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
 }

 #define instantiate_implicit_conv_2d( \
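
In the grouped implicit GEMM, each threadgroup along the z axis handles one group: the A/B/C pointers are offset by `tid.z`, and each group's (M, O_per_group) result is written into its own column slice of the full (M, O) output, hence the row stride `ldc = N * params->groups`. A NumPy sketch of that output layout (illustration only):

```python
import numpy as np

# G independent GEMMs whose results are interleaved column-wise in the output,
# mimicking the per-group tiles written with ldc = N * groups.
M, K, N, G = 8, 12, 16, 4
A = np.random.rand(G, M, K)  # per-group "im2col" activations (stand-in data)
B = np.random.rand(G, K, N)  # per-group weights (stand-in data)
out = np.zeros((M, G * N))
for g in range(G):
    out[:, g * N : (g + 1) * N] = A[g] @ B[g]
```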

View File

@@ -3180,9 +3180,9 @@ array conv_general(
     bool flip /* = false */,
     StreamOrDevice s /* = {} */) {
   // Run checks
-  if (groups != 1 && in.ndim() != 3) {
+  if (groups != 1 && in.ndim() != 3 && in.ndim() != 4) {
     throw std::invalid_argument(
-        "[conv] Can only handle groups != 1 in 1D convolutions.");
+        "[conv] Can only handle groups != 1 in 1D or 2D convolutions.");
   }

   int spatial_dims = in.ndim() - 2;

View File

@@ -123,9 +123,13 @@ class TestConv(mlx_tests.MLXTestCase):
         # Groups tests
         N, C, O = (4, 32, 64)
-        iH, kH, stride, padding = (31, 5, 1, 2)
-        for group in (1, 2, 4, 8, 16, 32):
-            run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype)
+        for iH, kH, stride, padding in (
+            (1, 1, 1, 0),
+            (3, 3, 1, 0),
+            (31, 5, 5, 2),
+        ):
+            for group in (1, 2, 4, 8, 16, 32):
+                run_conv1D(N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype)

         # Strided inputs tests
         for tpose_in, tpose_wt in (
@@ -291,7 +295,9 @@
             kH, kW = kdim
             scale = 1.0 / math.sqrt(kH * kW * C)
             in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype)
-            wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype)
+            wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype(
+                np_dtype
+            )

             in_mx, wt_mx = map(mx.array, (in_np, wt_np))
             in_pt, wt_pt = map(
@@ -334,6 +340,18 @@ class TestConv(mlx_tests.MLXTestCase):
         ):
             run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype)

+        # Groups tests
+        N, C, O = (4, 32, 64)
+        for idim, kdim, stride, padding in (
+            ((1, 1), (1, 1), (1, 1), (0, 0)),
+            ((3, 3), (3, 1), (1, 1), (0, 0)),
+            ((31, 31), (5, 5), (5, 5), (2, 2)),
+        ):
+            for group in (1, 2, 4, 8, 16, 32):
+                run_conv2D(
+                    N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype
+                )
+
     @unittest.skipIf(not has_torch, "requires Torch")
     def test_torch_conv_2D_grad(self):
         def run_conv2D_grad(

View File

@@ -3268,22 +3268,22 @@ TEST_CASE("test conv1d") {
         float16);

     auto expected = array(
-        {1.5685,
-         0.5672,
-         1.8121,
-         1.2948,
-         2.3448,
-         1.6104,
-         2.7743,
-         1.6126,
-         1.4056,
-         0.9331,
-         1.8739,
-         1.0909},
+        {1.56836,
+         0.567383,
+         1.8125,
+         1.29492,
+         2.34375,
+         1.61035,
+         2.77539,
+         1.61328,
+         1.40527,
+         0.933105,
+         1.87402,
+         1.09082},
         {1, 3, 4});

     auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
-    CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
+    CHECK(allclose(out, expected).item<bool>());
   }

   {
@@ -3309,22 +3309,151 @@ TEST_CASE("test conv1d") {
         {4, 3, 1});

     auto expected = array(
-        {1.0703,
-         0.7533,
-         0.7007,
-         0.4681,
-         1.1859,
-         0.9117,
-         0.9565,
-         0.6111,
-         0.6416,
-         0.5665,
-         0.9074,
-         0.0605},
+        {1.07007,
+         0.753201,
+         0.700818,
+         0.468176,
+         1.18568,
+         0.91152,
+         0.956607,
+         0.611213,
+         0.641404,
+         0.566401,
+         0.907472,
+         0.0605397},
         {1, 3, 4});

     auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
-    CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item<bool>());
+    CHECK(allclose(out, expected).item<bool>());
+  }
+}
+
+TEST_CASE("test conv2d") {
+  auto in = array(
+      {0.57429284,
+       -0.21628855,
+       -0.18673691,
+       -0.3793517,
+       0.3059678,
+       -0.8137168,
+       0.6168841,
+       -0.26912728},
+      {1, 2, 2, 2});
+
+  std::pair<int, int> kernel{2, 2};
+  std::pair<int, int> stride{1, 1};
+  std::pair<int, int> padding{0, 0};
+
+  {
+    int groups = 1;
+    auto wt = array(
+        {0.3190391, -0.24937038, 1.4621079, -2.0601406, -0.3224172,
+         -0.38405436, 1.1337694, -1.0998913, -0.1724282, -0.8778584,
+         0.04221375, 0.58281523, -1.1006192, 1.1447237, 0.9015907,
+         0.50249434, 0.90085596, -0.68372786, -0.12289023, -0.93576944,
+         -0.26788807, 0.53035545, -0.69166076, -0.39675352, -0.6871727,
+         -0.84520566, -0.6712461, -0.0126646, -1.1173104, 0.2344157,
+         1.6598022, 0.74204415},
+        {4, 2, 2, 2});
+    auto expected =
+        array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
+
+    auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
+    CHECK(allclose(out, expected).item<bool>());
+  }
+
+  {
+    int groups = 2;
+    auto wt = array(
+        {0.3190391,
+         -0.24937038,
+         1.46210794,
+         -2.06014071,
+         -0.3224172,
+         -0.38405435,
+         1.13376944,
+         -1.09989127,
+         -0.17242821,
+         -0.87785842,
+         0.04221375,
+         0.58281521,
+         -1.10061918,
+         1.14472371,
+         0.90159072,
+         0.50249434},
+        {4, 2, 2, 1});
+    auto expected = array(
+        {-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
+
+    auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
+    CHECK(allclose(out, expected).item<bool>());
+  }
+
+  {
+    in = array(
+        {0.57429284,
+         -0.21628855,
+         -0.18673691,
+         -0.3793517,
+         0.3059678,
+         -0.8137168,
+         0.6168841,
+         -0.26912728,
+         0.57429284,
+         -0.21628855,
+         -0.18673691,
+         -0.3793517,
+         0.3059678,
+         -0.8137168,
+         0.6168841,
+         -0.26912728},
+        {2, 2, 2, 2});
+
+    int groups = 2;
+    auto wt = array(
+        {0.3190391,
+         -0.24937038,
+         1.46210794,
+         -2.06014071,
+         -0.3224172,
+         -0.38405435,
+         1.13376944,
+         -1.09989127,
+         -0.17242821,
+         -0.87785842,
+         0.04221375,
+         0.58281521,
+         -1.10061918,
+         1.14472371,
+         0.90159072,
+         0.50249434},
+        {4, 2, 2, 1});
+    auto expected = array(
+        {-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
+
+    auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
+    CHECK(allclose(out, expected).item<bool>());
   }
 }