From 8777fd104f7c72d32cbb7ccf92754560ea35e7fa Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Thu, 3 Apr 2025 09:42:04 -0700 Subject: [PATCH] Depthwise Conv2D optimization (#2036) - Add new specialized kernel for small kernel (kernels size <= 7), small strides (strides <= 2) depthwise 2d convolutions - Add related tests --- mlx/backend/metal/conv.cpp | 72 +++++++++++++++- mlx/backend/metal/kernels/conv.metal | 122 +++++++++++++++++++++++++++ python/tests/test_conv.py | 42 ++++++++- 3 files changed, 232 insertions(+), 4 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index c4803a380..9075ea4c5 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -712,6 +712,65 @@ void winograd_conv_2D_gpu( } } +void depthwise_conv_2D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + const array& wt, + array out, + const MLXConvParams<2>& conv_params) { + std::ostringstream kname; + kname << "depthwise_conv_2d_" << type_to_name(out); + std::string base_name = kname.str(); + + const int N = conv_params.N; + const int ker_h = conv_params.wS[0]; + const int ker_w = conv_params.wS[1]; + const int str_h = conv_params.str[0]; + const int str_w = conv_params.str[1]; + const int tc = 8; + const int tw = 8; + const int th = 4; + const bool do_flip = conv_params.flip; + + metal::MTLFCList func_consts = { + {&ker_h, MTL::DataType::DataTypeInt, 00}, + {&ker_w, MTL::DataType::DataTypeInt, 01}, + {&str_h, MTL::DataType::DataTypeInt, 10}, + {&str_w, MTL::DataType::DataTypeInt, 11}, + {&th, MTL::DataType::DataTypeInt, 100}, + {&tw, MTL::DataType::DataTypeInt, 101}, + {&do_flip, MTL::DataType::DataTypeBool, 200}, + }; + + // clang-format off + kname << "_ker_h_" << ker_h + << "_ker_w_" << ker_w + << "_str_h_" << str_h + << "_str_w_" << str_w + << "_tgp_h_" << th + << "_tgp_w_" << tw + << "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on + + std::string hash_name = kname.str(); + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(in, 0); + compute_encoder.set_input_array(wt, 1); + compute_encoder.set_output_array(out, 2); + + compute_encoder.set_bytes(conv_params, 3); + + MTL::Size group_dims = MTL::Size(tc, tw, th); + MTL::Size grid_dims = MTL::Size( + conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + void conv_2D_gpu( const Stream& s, metal::Device& d, @@ -754,11 +813,20 @@ void conv_2D_gpu( bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; - if (groups > 1) { + if (is_idil_one && groups > 1) { const int C_per_group = conv_params.C / groups; const int O_per_group = conv_params.O / groups; - if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && + if (C_per_group == 1 && O_per_group == 1 && is_kdil_one && + conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 && + conv_params.str[0] <= 2 && conv_params.str[1] <= 2 && + conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 && + conv_params.wt_strides[1] == conv_params.wS[1] && + conv_params.C % 16 == 0 && conv_params.C == conv_params.O) { + return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params); + } + + if ((C_per_group <= 4 || C_per_group % 16 == 0) && (O_per_group <= 16 || O_per_group % 16 == 0)) { return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params); } else { diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 13ee239dc..620352144 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -275,6 +275,128 @@ instantiate_naive_conv_2d_blocks(float32, float); instantiate_naive_conv_2d_blocks(float16, half); instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); +/////////////////////////////////////////////////////////////////////////////// +/// Depthwise convolution kernels +/////////////////////////////////////////////////////////////////////////////// + +constant int ker_h [[function_constant(00)]]; +constant int ker_w [[function_constant(01)]]; +constant int str_h [[function_constant(10)]]; +constant int str_w [[function_constant(11)]]; +constant int tgp_h [[function_constant(100)]]; +constant int tgp_w [[function_constant(101)]]; +constant bool do_flip [[function_constant(200)]]; + +constant int span_h = tgp_h * str_h + ker_h - 1; +constant int span_w = tgp_w * str_w + ker_w - 1; +constant int span_hw = span_h * span_w; + +template +[[kernel]] void depthwise_conv_2d( + const device T* in [[buffer(0)]], + const device T* wt [[buffer(1)]], + device T* out [[buffer(2)]], + const constant MLXConvParams<2>& params [[buffer(3)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int tc = 8; + constexpr int tw = 8; + constexpr int th = 4; + + constexpr int c_per_thr = 8; + + constexpr int TGH = th * 2 + 6; + constexpr int TGW = tw * 2 + 6; + constexpr int TGC = tc; + + threadgroup T ins[TGH * TGW * TGC]; + + const int n_tgblocks_h = params.oS[0] / th; + const int n = tid.z / n_tgblocks_h; + const int tghid = tid.z % n_tgblocks_h; + const int oh = tghid * th + lid.z; + const int ow = gid.y; + const int c = gid.x; + + in += n * params.in_strides[0]; + + // Load in + { + constexpr int n_threads = th * tw * tc; + const int tg_oh = (tghid * th) * str_h - params.pad[0]; + const int tg_ow = (tid.y * tw) * str_w - params.pad[1]; + const int tg_c = tid.x * tc; + + const int thread_idx = simd_gid * 32 + simd_lid; + constexpr int thr_per_hw = tc / c_per_thr; + constexpr int hw_per_group = n_threads / thr_per_hw; + + const int thr_c = thread_idx % thr_per_hw; + const int thr_hw = thread_idx / thr_per_hw; + + for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) { + const int h = hw / span_w; + const int w = hw % span_w; + + const int ih = tg_oh + h; + const int iw = tg_ow + w; + + const int in_s_offset = h * span_w * TGC + w * TGC; + + if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) { + const auto in_load = + in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c; + + MLX_MTL_PRAGMA_UNROLL + for (int cc = 0; cc < c_per_thr; ++cc) { + ins[in_s_offset + c_per_thr * thr_c + cc] = + in_load[c_per_thr * thr_c + cc]; + } + } else { + MLX_MTL_PRAGMA_UNROLL + for (int cc = 0; cc < c_per_thr; ++cc) { + ins[in_s_offset + c_per_thr * thr_c + cc] = T(0); + } + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + wt += c * params.wt_strides[0]; + + const auto ins_ptr = + &ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x]; + float o = 0.; + for (int h = 0; h < ker_h; ++h) { + for (int w = 0; w < ker_w; ++w) { + int wt_h = h; + int wt_w = w; + if (do_flip) { + wt_h = ker_h - h - 1; + wt_w = ker_w - w - 1; + } + auto inv = ins_ptr[h * span_w * TGC + w * TGC]; + auto wtv = wt[wt_h * ker_w + wt_w]; + o += inv * wtv; + } + } + threadgroup_barrier(mem_flags::mem_none); + + out += n * params.out_strides[0] + oh * params.out_strides[1] + + ow * params.out_strides[2]; + out[c] = static_cast(o); +} + +#define instantiate_depthconv2d(iname, itype) \ + instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype) + +instantiate_depthconv2d(float32, float); +instantiate_depthconv2d(float16, half); +instantiate_depthconv2d(bfloat16, bfloat16_t); + /////////////////////////////////////////////////////////////////////////////// /// Winograd kernels /////////////////////////////////////////////////////////////////////////////// diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py index 9dd8fd140..671c86a32 100644 --- a/python/tests/test_conv.py +++ b/python/tests/test_conv.py @@ -707,9 +707,11 @@ class TestConv(mlx_tests.MLXTestCase): flip=flip, np_dtype=np_dtype, ): + np.random.seed(0) scale = 1.0 / math.sqrt(np.prod(wt_shape[1:])) - in_np = np.random.normal(0.0, scale, in_shape).astype(np_dtype) - wt_np = np.random.normal(0.0, scale, wt_shape).astype(np_dtype) + scale = min(0.3, scale) + in_np = np.random.normal(0, scale, in_shape).astype(np_dtype) + wt_np = np.random.normal(0, scale, wt_shape).astype(np_dtype) in_mx, wt_mx = map(mx.array, (in_np, wt_np)) @@ -1050,6 +1052,42 @@ class TestConv(mlx_tests.MLXTestCase): y2 = mx.conv2d(x, w, (1, 1), (1, 1), (1, 1), 1) self.assertTrue(mx.allclose(y1, y2)) + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_conv_depthwise(self): + + # fmt: off + shapes = ( + # N, H, W, C kH, kW, O, strides, padding, groups + ( 2, 16, 16, 32, 1, 1, 32, (2, 2), (1, 1), 32), + ( 1, 16, 16, 32, 3, 3, 32, (2, 2), (1, 1), 32), + ( 1, 32, 32, 32, 7, 7, 32, (1, 1), (3, 3), 32), + ( 3, 32, 32, 32, 5, 5, 32, (1, 2), (0, 0), 32), + ( 1, 32, 32, 32, 7, 7, 32, (2, 1), (1, 3), 32), + ) + # fmt: on + + dtypes = [np.float32] + if mx.default_device() == mx.gpu: + dtypes += [np.float16] + + for N, H, W, C, kH, kW, O, strides, padding, groups in shapes: + for dtype in dtypes: + for flip in [False, True]: + Cw = C // groups + + self.__conv_general_test( + (N, H, W, C), + (O, kH, kW, Cw), + strides, + padding, + kernel_dilation=1, + input_dilation=1, + groups=groups, + flip=flip, + np_dtype=dtype, + atol=2e-5 if dtype == np.float32 else 5e-4, + ) + if __name__ == "__main__": unittest.main()