From 9401507336287634eab983c8a0ac13b96a6ff224 Mon Sep 17 00:00:00 2001 From: Rifur13 Date: Wed, 22 May 2024 23:01:44 -0400 Subject: [PATCH] Add groups to 2-D convolutions (#1129) * Added groups to 2-D convolutions. Only implemented for **some** specializations. Also fixed 1D grouped convs with different kernel strides and added more tests. * fix channels condition --- ACKNOWLEDGMENTS.md | 1 + benchmarks/python/conv_bench.py | 69 ++++--- mlx/backend/common/conv.cpp | 112 ++++++----- mlx/backend/metal/conv.cpp | 46 +++-- mlx/backend/metal/kernels/conv.metal | 4 +- .../steel/conv/kernels/steel_conv.metal | 11 +- mlx/ops.cpp | 4 +- python/tests/test_conv.py | 26 ++- tests/ops_tests.cpp | 181 +++++++++++++++--- 9 files changed, 322 insertions(+), 132 deletions(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 05dca2768..598161f78 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -16,6 +16,7 @@ MLX was developed with contributions from the following individuals: - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays. - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`. +- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions. diff --git a/benchmarks/python/conv_bench.py b/benchmarks/python/conv_bench.py index f052487d9..7ca4d8a7c 100644 --- a/benchmarks/python/conv_bench.py +++ b/benchmarks/python/conv_bench.py @@ -28,11 +28,11 @@ def bench(f, a, b): return (e - s) * 1e-9 -def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)): +def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): def mx_conv_2D(a, b): ys = [] for i in range(N_iter_func): - y = mx.conv2d(a, b, stride=strides, padding=padding) + y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) mx.eval(ys) return ys @@ -40,12 +40,12 @@ def make_mx_conv_2D(strides=(1, 1), padding=(0, 0)): return mx_conv_2D -def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)): +def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1): @torch.no_grad() def pt_conv_2D(a, b): ys = [] for i in range(N_iter_func): - y = torch.conv2d(a, b, stride=strides, padding=padding) + y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups) ys.append(y) torch.mps.synchronize() return ys @@ -53,11 +53,13 @@ def make_pt_conv_2D(strides=(1, 1), padding=(0, 0)): return pt_conv_2D -def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype): +def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype): scale = 1.0 / math.sqrt(kH * kH * C) a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype) - b_np = np.random.uniform(-scale, scale, (O, kH, kW, C)).astype(np_dtype) + b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype( + np_dtype + ) a_mx = mx.array(a_np) b_mx = mx.array(b_np) @@ -67,15 +69,15 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype): torch.mps.synchronize() - f_mx = make_mx_conv_2D(strides, padding) - f_pt = make_pt_conv_2D(strides, padding) + f_mx = make_mx_conv_2D(strides, padding, groups) + f_pt = make_pt_conv_2D(strides, padding, groups) time_torch = bench(f_pt, a_pt, b_pt) time_mlx = bench(f_mx, a_mx, b_mx) - out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding) + out_mx = mx.conv2d(a_mx, b_mx, 
stride=strides, padding=padding, groups=groups) out_pt = torch.conv2d( - a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding + a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups ) out_pt = torch.permute(out_pt, (0, 2, 3, 1)) out_pt = out_pt.numpy(force=True) @@ -84,7 +86,7 @@ def bench_shape(N, H, W, C, kH, kW, O, strides, padding, np_dtype): if not np.allclose(out_pt, out_mx, atol=atol): print( - f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" + f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}" ) return time_mlx, time_torch @@ -95,35 +97,40 @@ if __name__ == "__main__": dtypes = ("float32",) shapes = ( - (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2)), - (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2)), - (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2)), - (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2)), - (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2)), - (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2)), - (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2)), - (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2)), - (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2)), - (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2)), - (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2)), - (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2)), - (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2)), - (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2)), - (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2)), - (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2)), + (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1), + (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1), + (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1), + (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2), + (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16), + (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64), + (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1), + (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1), + (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1), + (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1), + (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1), + (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1), ) for dtype in dtypes: - print("(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, diff%") - for N, H, W, C, kH, kW, O, strides, padding in shapes: + print( + "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%" + ) + for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: np_dtype = getattr(np, dtype) time_mlx, time_torch = bench_shape( - N, H, W, C, kH, kW, O, strides, padding, np_dtype + N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype ) diff = time_torch / time_mlx - 1.0 print( - f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {100. * diff:+5.2f}%" + f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. 
* diff:+5.2f}%" ) if time_mlx >= 2.0 * time_torch: print("ATTENTION ^^^^^^^") diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp index 0b19e60e3..79bc3c4a1 100644 --- a/mlx/backend/common/conv.cpp +++ b/mlx/backend/common/conv.cpp @@ -111,13 +111,17 @@ void slow_conv_2D( const int N = in.shape(0); // Batch size, should be the same as out.shape(0) const int iH = 1 + in_dilation[0] * (in.shape(1) - 1); // Input spatial dim const int iW = 1 + in_dilation[1] * (in.shape(2) - 1); // Input spatial dim + const int C = in.shape(3); // In channels const int oH = out.shape(1); // Output spatial dim const int oW = out.shape(2); // Output spatial dim const int O = wt.shape(0); // Out channels - const int C = wt.shape(3); // In channels const int wH = wt.shape(1); // Weight spatial dim const int wW = wt.shape(2); // Weight spatial dim + const int groups = C / wt.shape(3); + const int C_per_group = wt.shape(3); + const int O_per_group = O / groups; + const size_t in_stride_N = in.strides()[0]; const size_t in_stride_H = in.strides()[1]; const size_t in_stride_W = in.strides()[2]; @@ -141,33 +145,35 @@ void slow_conv_2D( int ih_base = oh * wt_strides[0] - padding[0]; int iw_base = ow * wt_strides[1] - padding[1]; - for (int o = 0; o < O; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = 0; wh < wH; ++wh) { - for (int ww = 0; ww < wW; ++ww) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int wh_flip = flip ? wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = + in_ptr + ih * in_stride_H + iw * in_stride_W; - for (int c = 0; c < C; ++c) { - r += static_cast(in_ptr_pt[0]) * - static_cast(wt_ptr_pt[0]); - in_ptr_pt += in_stride_C; - wt_ptr_pt += wt_stride_C; - } // c + for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c + } // ww + } // wh - } // ww - } // wh - - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g }; int jump_h = flip ? -wt_dilation[0] : wt_dilation[0]; @@ -219,41 +225,43 @@ void slow_conv_2D( int wh_base = base_h[oh % f_out_jump_h]; int ww_base = base_w[ow % f_out_jump_w]; - for (int o = 0; o < O; ++o) { - float r = 0.; + for (int g = 0; g < groups; ++g) { + for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) { + float r = 0.; - for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { - for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { - int wh_flip = flip ? wH - wh - 1 : wh; - int ww_flip = flip ? wW - ww - 1 : ww; - int ih = ih_base + wh_flip * wt_dilation[0]; - int iw = iw_base + ww_flip * wt_dilation[1]; + for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) { + for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) { + int wh_flip = flip ? 
wH - wh - 1 : wh; + int ww_flip = flip ? wW - ww - 1 : ww; + int ih = ih_base + wh_flip * wt_dilation[0]; + int iw = iw_base + ww_flip * wt_dilation[1]; - if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { - const T* wt_ptr_pt = - wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; - int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; - int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; + int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih; + int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw; - const T* in_ptr_pt = - in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; + const T* in_ptr_pt = + in_ptr + ih_dil * in_stride_H + iw_dil * in_stride_W; - for (int c = 0; c < C; ++c) { - r += static_cast(in_ptr_pt[0]) * - static_cast(wt_ptr_pt[0]); - in_ptr_pt += in_stride_C; - wt_ptr_pt += wt_stride_C; - } // c + for (int c = g * C_per_group; c < (g + 1) * C_per_group; + ++c) { + r += static_cast(in_ptr_pt[c * in_stride_C]) * + static_cast( + wt_ptr_pt[(c % C_per_group) * wt_stride_C]); + } // c - } // ih, iw check - } // ww - } // wh + } // ih, iw check + } // ww + } // wh - out_ptr[0] = static_cast(r); - out_ptr += out_stride_O; - wt_ptr += wt_stride_O; - } // o + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + } // g }; int oH_border_0 = 0; diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 7fcb6592e..03fda47e4 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -257,15 +257,19 @@ void implicit_gemm_conv_2D_gpu( const array& wt, array out, const MLXConvParams<2>& conv_params) { + const int groups = conv_params.groups; + const int C_per_group = conv_params.C / conv_params.groups; + const int O_per_group = conv_params.O / conv_params.groups; + // Deduce implicit gemm size - int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1]; - int implicit_N = conv_params.O; - int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C; + const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1]; + const int implicit_N = O_per_group; + const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group; // Determine block and warp tiles int wm = 2, wn = 2; - int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32; + int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32; int bn = (bm == 64 || implicit_N >= 64) ? 
64 : 32; int bk = 16; @@ -281,15 +285,15 @@ void implicit_gemm_conv_2D_gpu( // Fix small channel specialization int n_channel_specialization = 0; - int channel_k_iters = ((conv_params.C + bk - 1) / bk); + int channel_k_iters = ((C_per_group + bk - 1) / bk); int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters; - if (conv_params.C <= 2) { + if (C_per_group <= 2) { gemm_k_iters = (implicit_K + bk - 1) / bk; - n_channel_specialization = conv_params.C; - } else if (conv_params.C <= 4) { + n_channel_specialization = C_per_group; + } else if (C_per_group <= 4) { gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk; - n_channel_specialization = conv_params.C; + n_channel_specialization = C_per_group; } bool small_filter = (!n_channel_specialization) && @@ -340,7 +344,7 @@ void implicit_gemm_conv_2D_gpu( size_t grid_dim_x = tn * tile; MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1); + MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups); // Encode arrays compute_encoder.set_input_array(in, 0); @@ -703,6 +707,7 @@ void conv_2D_gpu( const std::vector& wt_strides, const std::vector& wt_dilation, const std::vector& in_dilation, + const int groups, bool flip, std::vector& copies) { // Make conv params @@ -718,12 +723,12 @@ void conv_2D_gpu( /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]}, /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]}, /* const size_t in_strides[NDIM + 2] = */ - {in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]}, + {in.strides(0), in.strides(1), in.strides(2), in.strides(3)}, /* const size_t wt_strides[NDIM + 2] = */ - {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]}, + {wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)}, /* const size_t out_strides[NDIM + 2] = */ - {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]}, - /* const int groups = */ 1, + {out.strides(0), out.strides(1), out.strides(2), out.strides(3)}, + /* const int groups = */ groups, /* const bool flip = */ flip, }; @@ -735,6 +740,18 @@ void conv_2D_gpu( bool channels_large = (conv_params.C + conv_params.O) >= 512; bool channels_med = (conv_params.C + conv_params.O) >= 256; + if (groups > 1) { + const int C_per_group = conv_params.C / groups; + const int O_per_group = conv_params.O / groups; + + if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && + (O_per_group <= 16 || O_per_group % 16 == 0)) { + return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); + } else { + return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params); + } + } + // Direct to winograd conv if (!flip && is_stride_one && is_kdil_one && is_idil_one && conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && @@ -860,6 +877,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { kernel_strides_, kernel_dilation_, input_dilation_, + groups_, flip_, copies); } diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 4c65a3677..92e91505d 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -109,6 +109,7 @@ template bool valid = n < params->N; // Unroll dimensions + int kernel_stride = 1; for (int i = N - 1; i >= 0; --i) { int os_ = (oS % params->oS[i]); int ws_ = (wS % params->wS[i]); @@ -125,7 +126,8 @@ template oS /= params->oS[i]; wS /= params->wS[i]; - out += ws_ * params->str[i]; + out += ws_ * kernel_stride; + 
kernel_stride *= params->wS[i]; } if (valid) { diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal index 39953c2be..ee5bcb285 100644 --- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal @@ -133,9 +133,15 @@ implicit_gemm_conv_2d( const int c_col = tid_x * BN; const int K = gemm_params->K; const int N = gemm_params->N; + const int C_per_group = params->C / params->groups; + + // Groups + A += tid.z * C_per_group; + B += tid.z * N * K; + C += tid.z * N; B += c_col * K; - C += c_row * N + c_col; + C += c_row * (N * params->groups) + c_col; const int2 offsets_a(0, c_row); const int2 offsets_b(0, c_col); @@ -171,7 +177,8 @@ implicit_gemm_conv_2d( // Store results to device memory short tgp_bm = min(BM, gemm_params->M - c_row); short tgp_bn = min(BN, gemm_params->N - c_col); - mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm)); + const int ldc = N * params->groups; + mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); } #define instantiate_implicit_conv_2d( \ diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 22680dd7f..ae0c3a258 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3180,9 +3180,9 @@ array conv_general( bool flip /* = false */, StreamOrDevice s /* = {} */) { // Run checks - if (groups != 1 && in.ndim() != 3) { + if (groups != 1 && in.ndim() != 3 && in.ndim() != 4) { throw std::invalid_argument( - "[conv] Can only handle groups != 1 in 1D convolutions."); + "[conv] Can only handle groups != 1 in 1D or 2D convolutions."); } int spatial_dims = in.ndim() - 2; diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 1111ce80a..754f1727e 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -123,9 +123,13 @@ class TestConv(mlx_tests.MLXTestCase): # Groups tests N, C, O = (4, 32, 64) - iH, kH, stride, padding = (31, 5, 1, 2) - for group in (1, 2, 4, 8, 16, 32): - run_conv1D(N, C, O, iH, kH, stride=1, padding=1, groups=group, dtype=dtype) + for iH, kH, stride, padding in ( + (1, 1, 1, 0), + (3, 3, 1, 0), + (31, 5, 5, 2), + ): + for group in (1, 2, 4, 8, 16, 32): + run_conv1D(N, C, O, iH, kH, stride, padding, groups=group, dtype=dtype) # Strided inputs tests for tpose_in, tpose_wt in ( @@ -291,7 +295,9 @@ class TestConv(mlx_tests.MLXTestCase): kH, kW = kdim scale = 1.0 / math.sqrt(kH * kW * C) in_np = np.random.normal(0.0, scale, (N, iH, iW, C)).astype(np_dtype) - wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, C)).astype(np_dtype) + wt_np = np.random.normal(0.0, 1.0, (O, kH, kW, int(C / groups))).astype( + np_dtype + ) in_mx, wt_mx = map(mx.array, (in_np, wt_np)) in_pt, wt_pt = map( @@ -334,6 +340,18 @@ class TestConv(mlx_tests.MLXTestCase): ): run_conv2D(N, C, O, idim, kdim, stride, padding, dtype=dtype) + # Groups tests + N, C, O = (4, 32, 64) + for idim, kdim, stride, padding in ( + ((1, 1), (1, 1), (1, 1), (0, 0)), + ((3, 3), (3, 1), (1, 1), (0, 0)), + ((31, 31), (5, 5), (5, 5), (2, 2)), + ): + for group in (1, 2, 4, 8, 16, 32): + run_conv2D( + N, C, O, idim, kdim, stride, padding, groups=group, dtype=dtype + ) + @unittest.skipIf(not has_torch, "requires Torch") def test_torch_conv_2D_grad(self): def run_conv2D_grad( diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c1bbeb051..7aa3a3450 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -3268,22 +3268,22 @@ TEST_CASE("test conv1d") { float16); auto expected = array( - {1.5685, - 0.5672, - 1.8121, - 1.2948, - 
2.3448, - 1.6104, - 2.7743, - 1.6126, - 1.4056, - 0.9331, - 1.8739, - 1.0909}, + {1.56836, + 0.567383, + 1.8125, + 1.29492, + 2.34375, + 1.61035, + 2.77539, + 1.61328, + 1.40527, + 0.933105, + 1.87402, + 1.09082}, {1, 3, 4}); auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups); - CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item()); + CHECK(allclose(out, expected).item()); } { @@ -3309,22 +3309,151 @@ TEST_CASE("test conv1d") { {4, 3, 1}); auto expected = array( - {1.0703, - 0.7533, - 0.7007, - 0.4681, - 1.1859, - 0.9117, - 0.9565, - 0.6111, - 0.6416, - 0.5665, - 0.9074, - 0.0605}, + {1.07007, + 0.753201, + 0.700818, + 0.468176, + 1.18568, + 0.91152, + 0.956607, + 0.611213, + 0.641404, + 0.566401, + 0.907472, + 0.0605397}, {1, 3, 4}); auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups); - CHECK(allclose(out, expected, /* rtol = */ 1.0e-3).item()); + CHECK(allclose(out, expected).item()); + } +} + +TEST_CASE("test conv2d") { + auto in = array( + {0.57429284, + -0.21628855, + -0.18673691, + -0.3793517, + + 0.3059678, + -0.8137168, + 0.6168841, + -0.26912728}, + {1, 2, 2, 2}); + + std::pair kernel{2, 2}; + std::pair stride{1, 1}; + std::pair padding{0, 0}; + + { + int groups = 1; + + auto wt = array( + {0.3190391, -0.24937038, 1.4621079, -2.0601406, -0.3224172, + -0.38405436, 1.1337694, -1.0998913, -0.1724282, -0.8778584, + 0.04221375, 0.58281523, -1.1006192, 1.1447237, 0.9015907, + 0.50249434, 0.90085596, -0.68372786, -0.12289023, -0.93576944, + -0.26788807, 0.53035545, -0.69166076, -0.39675352, -0.6871727, + -0.84520566, -0.6712461, -0.0126646, -1.1173104, 0.2344157, + 1.6598022, 0.74204415}, + {4, 2, 2, 2}); + + auto expected = + array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4}); + auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups); + CHECK(allclose(out, expected).item()); + } + + { + int groups = 2; + auto wt = array( + {0.3190391, + -0.24937038, + + 1.46210794, + -2.06014071, + + -0.3224172, + -0.38405435, + + 1.13376944, + -1.09989127, + + -0.17242821, + -0.87785842, + + 0.04221375, + 0.58281521, + + -1.10061918, + 1.14472371, + + 0.90159072, + 0.50249434}, + {4, 2, 2, 1}); + + auto expected = array( + {-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4}); + + auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups); + CHECK(allclose(out, expected).item()); + } + + { + in = array( + {0.57429284, + -0.21628855, + -0.18673691, + -0.3793517, + + 0.3059678, + -0.8137168, + 0.6168841, + -0.26912728, + + 0.57429284, + -0.21628855, + -0.18673691, + -0.3793517, + + 0.3059678, + -0.8137168, + 0.6168841, + -0.26912728}, + {2, 2, 2, 2}); + + int groups = 2; + auto wt = array( + {0.3190391, + -0.24937038, + + 1.46210794, + -2.06014071, + + -0.3224172, + -0.38405435, + + 1.13376944, + -1.09989127, + + -0.17242821, + -0.87785842, + + 0.04221375, + 0.58281521, + + -1.10061918, + 1.14472371, + + 0.90159072, + 0.50249434}, + {4, 2, 2, 1}); + + auto expected = array( + {-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4}); + + auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups); + CHECK(allclose(out, expected).item()); } }
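
Usage note (not part of the patch): a minimal sketch of the grouped 2-D convolution this change enables through the Python API. The weight layout follows the tests above, `(O, kH, kW, C // groups)`; the sizes are arbitrary, and the per-group slicing comparison at the end only illustrates the intended semantics.

```python
import mlx.core as mx

N, H, W, C, O, kH, kW, groups = 2, 16, 16, 32, 64, 3, 3, 4

x = mx.random.normal((N, H, W, C))              # NHWC input
w = mx.random.normal((O, kH, kW, C // groups))  # each filter sees C // groups channels

y = mx.conv2d(x, w, stride=(1, 1), padding=(1, 1), groups=groups)
print(y.shape)  # (2, 16, 16, 64)

# A grouped convolution is equivalent to `groups` independent convolutions
# over channel slices of the input and output-channel slices of the weights.
C_per_group, O_per_group = C // groups, O // groups
parts = []
for g in range(groups):
    x_g = x[..., g * C_per_group : (g + 1) * C_per_group]
    w_g = w[g * O_per_group : (g + 1) * O_per_group]
    parts.append(mx.conv2d(x_g, w_g, stride=(1, 1), padding=(1, 1)))
print(mx.allclose(y, mx.concatenate(parts, axis=-1), atol=1e-5).item())  # True
```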
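
For reference, a naive NumPy version (again not part of the patch, and ignoring kernel dilation, input dilation and flipping) that mirrors the loop structure the patch adds to `slow_conv_2D`: iterate groups, then the output channels of each group, and accumulate only over that group's input channels.

```python
import numpy as np

def grouped_conv2d_ref(x, w, stride=(1, 1), padding=(0, 0), groups=1):
    """Naive NHWC grouped convolution (no dilation/flip), for illustration only."""
    N, H, W, C = x.shape
    O, kH, kW, C_per_group = w.shape
    assert C == C_per_group * groups and O % groups == 0
    O_per_group = O // groups

    xp = np.pad(x, ((0, 0), (padding[0],) * 2, (padding[1],) * 2, (0, 0)))
    oH = (H + 2 * padding[0] - kH) // stride[0] + 1
    oW = (W + 2 * padding[1] - kW) // stride[1] + 1
    out = np.zeros((N, oH, oW, O), dtype=x.dtype)

    for g in range(groups):                                      # group loop
        cs = slice(g * C_per_group, (g + 1) * C_per_group)       # this group's input channels
        for o in range(g * O_per_group, (g + 1) * O_per_group):  # this group's output channels
            for n in range(N):
                for oh in range(oH):
                    for ow in range(oW):
                        patch = xp[n,
                                   oh * stride[0] : oh * stride[0] + kH,
                                   ow * stride[1] : ow * stride[1] + kW,
                                   cs]
                        out[n, oh, ow, o] = np.sum(patch * w[o])
    return out
```

The Metal implicit-GEMM path in the patch realizes the same semantics by launching one grid slice per group (`grid_dims` gets a `groups`-sized z dimension) and offsetting `A`, `B`, and `C` by the per-group channel counts inside the kernel.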