	faster depthwise 1D conv (#2567)

@@ -2,7 +2,6 @@
 #include <algorithm>
 #include <cassert>
 #include <numeric>
-#include <sstream>
 
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
   in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
-  std::ostringstream kname;
-  kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
+  std::string kname;
+  kname.reserve(32);
+  concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(in, 0);
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
   in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
-  std::ostringstream kname;
-  kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
-        << N;
+  std::string kname;
+  kname.reserve(32);
+  concatenate(
+      kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(in, 0);
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
       /* const int swizzle_log = */ swizzle_log};
 
   // Determine kernel
-  std::ostringstream kname;
-  kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
-        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
-        << (n_channel_specialization ? std::to_string(n_channel_specialization)
-                                     : "l")
-        << "_filter_" << (small_filter ? 's' : 'l');
+  std::string kname;
+  kname.reserve(64);
+  concatenate(
+      kname,
+      "implicit_gemm_conv_2d_",
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn,
+      "_channel_",
+      n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
+      "_filter_",
+      small_filter ? 's' : 'l');
 
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = get_steel_conv_kernel(
       d,
-      kname.str(),
+      kname,
       out,
       bm,
       bn,
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
   {
     int bc = 32;
     int bo = 4;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_weight_transform_",
+        type_to_name(out),
+        "_bc",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
     int bc = 32;
     int wm = 2;
     int wn = 2;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_input_transform_",
+        type_to_name(out),
+        "_bc",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
     int bc = 32;
     int wm = 2;
     int wn = 2;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_output_transform_",
+        type_to_name(out),
+        "_bo",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
     const array& wt,
     array out,
     const MLXConvParams<2>& conv_params) {
-  std::ostringstream kname;
-  kname << "depthwise_conv_2d_" << type_to_name(out);
-  std::string base_name = kname.str();
+  std::string base_name;
+  base_name.reserve(32);
+  concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
 
   const int N = conv_params.N;
   const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
   };
 
   // clang-format off
-  kname << "_ker_h_" << ker_h
-        << "_ker_w_" << ker_w
-        << "_str_h_" << str_h
-        << "_str_w_" << str_w
-        << "_tgp_h_" << th
-        << "_tgp_w_" << tw
-        << "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
-
-  std::string hash_name = kname.str();
+  std::string hash_name;
+  hash_name.reserve(64);
+  concatenate(
+      hash_name,
+      base_name,
+  "_ker_h_", ker_h,
+  "_ker_w_", ker_w,
+  "_str_h_", str_h,
+  "_str_w_", str_w,
+  "_tgp_h_", th,
+  "_tgp_w_", tw,
+  "_do_flip_", do_flip ? 't' : 'n'); // clang-format on
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(base_name, hash_name, func_consts);
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
   }
 }
 
+void depthwise_conv_1D_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    array wt,
+    array out) {
+  bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
+  std::string base_name;
+  base_name.reserve(32);
+  concatenate(
+      base_name,
+      "depthwise_conv_1d_",
+      large ? "_large" : "",
+      type_to_name(out));
+
+  if (!wt.flags().row_contiguous) {
+    wt = contiguous_copy_gpu(wt, s);
+    d.add_temporary(wt, s.index);
+  }
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(base_name);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  auto B = in.shape(0);
+  auto Tout = out.shape(1);
+  auto D = in.shape(2);
+  auto K = wt.shape(1);
+
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
+  if (large) {
+    int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
+    compute_encoder.set_bytes(strides, 3, 3);
+
+  } else {
+    int strides[3] = {
+        static_cast<int>(in.strides(0)),
+        static_cast<int>(in.strides(1)),
+        static_cast<int>(in.strides(2))};
+    compute_encoder.set_bytes(strides, 3, 3);
+  }
+
+  compute_encoder.set_bytes(K, 4);
+  auto group_dims = get_block_dims(D, Tout, B);
+  MTL::Size grid_dims = MTL::Size(D, Tout, B);
+
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
+}
+
 void conv_1D_gpu(
     const Stream& s,
     metal::Device& d,
@@ -790,8 +873,15 @@ void conv_1D_gpu(
   bool is_idil_one = in_dilation[0] == 1;
   int C = in.shape(2);
   int O = wt.shape(0);
-  const int C_per_group = in.shape(2) / groups;
-  const int O_per_group = wt.shape(0) / groups;
+  // Fast path for fully separable 1D convolution
+  if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
+      wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
+    depthwise_conv_1D_gpu(s, d, in, wt, out);
+    return;
+  }
+
+  const int C_per_group = C / groups;
+  const int O_per_group = O / groups;
 
   // Direct to implicit gemm conv
   if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&

@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
 instantiate_depthconv2d(float16, half);
 instantiate_depthconv2d(bfloat16, bfloat16_t);
 
+template <typename T, typename IdxT>
+[[kernel]] void depthwise_conv_1d(
+    const device T* in [[buffer(0)]],
+    const device T* w [[buffer(1)]],
+    device T* out [[buffer(2)]],
+    constant const IdxT strides[3],
+    constant const int& kernel_size,
+    uint3 tid [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
+  in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
+  w += tid.x * kernel_size;
+
+  float acc = 0.0;
+  for (int i = 0; i < kernel_size; ++i) {
+    acc += static_cast<float>(in[0]) * w[i];
+    in += strides[1];
+  }
+  *out = static_cast<T>(acc);
+}
+
+#define instantiate_depthconv1d(iname, itype)                         \
+  instantiate_kernel(                                                 \
+      "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
+      instantiate_kernel(                                             \
+          "depthwise_conv_1d_" #iname "_large",                       \
+          depthwise_conv_1d,                                          \
+          itype,                                                      \
+          int64_t)
+
+instantiate_depthconv1d(float32, float);
+instantiate_depthconv1d(float16, half);
+instantiate_depthconv1d(bfloat16, bfloat16_t);
+
 ///////////////////////////////////////////////////////////////////////////////
 /// Winograd kernels
 ///////////////////////////////////////////////////////////////////////////////
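
The fast path added to conv_1D_gpu above fires only for a fully depthwise, unit-stride case: input dilation 1, groups equal to both the input and output channel counts, weight stride and dilation 1, no padding, and no flip. Below is a minimal sketch of a call that should take this path through the public C++ API; the mlx::core::conv1d signature and the random::normal helper used here are assumptions for illustration, not part of this diff.

// Minimal usage sketch (not part of this commit): a depthwise 1D convolution
// that should satisfy the fast-path condition in conv_1D_gpu.
// Assumed: mlx::core::conv1d(input, weight, stride, padding, dilation, groups)
// with NLC input and (O, K, C/groups) weights.
#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  int B = 4, T = 1024, C = 64, K = 5;
  auto x = mx::random::normal({B, T, C}); // input: (batch, length, channels)
  auto w = mx::random::normal({C, K, 1}); // one length-K filter per channel
  // groups == C == O, stride 1, no padding, dilation 1, no flip
  auto y = mx::conv1d(x, w, /*stride=*/1, /*padding=*/0, /*dilation=*/1, /*groups=*/C);
  mx::eval(y); // output shape: (B, T - K + 1, C)
  return 0;
}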