diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 06d058eae..b4a674ff0 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" @@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu( in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes())); // Prepare unfolding kernel - std::ostringstream kname; - kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N; + std::string kname; + kname.reserve(32); + concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); @@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu( in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes())); // Prepare unfolding kernel - std::ostringstream kname; - kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_" - << N; + std::string kname; + kname.reserve(32); + concatenate( + kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in, 0); @@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu( /* const int swizzle_log = */ swizzle_log}; // Determine kernel - std::ostringstream kname; - kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn" - << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_" - << (n_channel_specialization ? std::to_string(n_channel_specialization) - : "l") - << "_filter_" << (small_filter ? 's' : 'l'); + std::string kname; + kname.reserve(64); + concatenate( + kname, + "implicit_gemm_conv_2d_", + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn, + "_channel_", + n_channel_specialization ? std::to_string(n_channel_specialization) : "l", + "_filter_", + small_filter ? 's' : 'l'); // Encode and dispatch kernel auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = get_steel_conv_kernel( d, - kname.str(), + kname, out, bm, bn, @@ -559,11 +574,16 @@ void winograd_conv_2D_gpu( { int bc = 32; int bo = 4; - std::ostringstream kname; - kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc" - << bc; + std::string kname; + kname.reserve(32); + concatenate( + kname, + "winograd_conv_2d_weight_transform_", + type_to_name(out), + "_bc", + bc); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(wt, 0); @@ -587,11 +607,16 @@ void winograd_conv_2D_gpu( int bc = 32; int wm = 2; int wn = 2; - std::ostringstream kname; - kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc" - << bc; + std::string kname; + kname.reserve(32); + concatenate( + kname, + "winograd_conv_2d_input_transform_", + type_to_name(out), + "_bc", + bc); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(in_padded, 0); @@ -634,11 +659,16 @@ void winograd_conv_2D_gpu( int bc = 32; int wm = 2; int wn = 2; - std::ostringstream kname; - kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo" - << bc; + std::string kname; + kname.reserve(32); + concatenate( + kname, + "winograd_conv_2d_output_transform_", + type_to_name(out), + "_bo", + bc); auto& compute_encoder = d.get_command_encoder(s.index); - auto kernel = d.get_kernel(kname.str()); + auto kernel = d.get_kernel(kname); compute_encoder.set_compute_pipeline_state(kernel); compute_encoder.set_input_array(out_wg, 0); @@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu( const array& wt, array out, const MLXConvParams<2>& conv_params) { - std::ostringstream kname; - kname << "depthwise_conv_2d_" << type_to_name(out); - std::string base_name = kname.str(); + std::string base_name; + base_name.reserve(32); + concatenate(base_name, "depthwise_conv_2d_", type_to_name(out)); const int N = conv_params.N; const int ker_h = conv_params.wS[0]; @@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu( }; // clang-format off - kname << "_ker_h_" << ker_h - << "_ker_w_" << ker_w - << "_str_h_" << str_h - << "_str_w_" << str_w - << "_tgp_h_" << th - << "_tgp_w_" << tw - << "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on - - std::string hash_name = kname.str(); + std::string hash_name; + hash_name.reserve(64); + concatenate( + hash_name, + base_name, + "_ker_h_", ker_h, + "_ker_w_", ker_w, + "_str_h_", str_h, + "_str_w_", str_w, + "_tgp_h_", th, + "_tgp_w_", tw, + "_do_flip_", do_flip ? 't' : 'n'); // clang-format on auto& compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(base_name, hash_name, func_consts); @@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu( } } +void depthwise_conv_1D_gpu( + const Stream& s, + metal::Device& d, + const array& in, + array wt, + array out) { + bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX; + std::string base_name; + base_name.reserve(32); + concatenate( + base_name, + "depthwise_conv_1d_", + large ? "_large" : "", + type_to_name(out)); + + if (!wt.flags().row_contiguous) { + wt = contiguous_copy_gpu(wt, s); + d.add_temporary(wt, s.index); + } + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(base_name); + compute_encoder.set_compute_pipeline_state(kernel); + + auto B = in.shape(0); + auto Tout = out.shape(1); + auto D = in.shape(2); + auto K = wt.shape(1); + + compute_encoder.set_input_array(in, 0); + compute_encoder.set_input_array(wt, 1); + compute_encoder.set_output_array(out, 2); + if (large) { + int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)}; + compute_encoder.set_bytes(strides, 3, 3); + + } else { + int strides[3] = { + static_cast(in.strides(0)), + static_cast(in.strides(1)), + static_cast(in.strides(2))}; + compute_encoder.set_bytes(strides, 3, 3); + } + + compute_encoder.set_bytes(K, 4); + auto group_dims = get_block_dims(D, Tout, B); + MTL::Size grid_dims = MTL::Size(D, Tout, B); + + compute_encoder.dispatch_threads(grid_dims, group_dims); +} + void conv_1D_gpu( const Stream& s, metal::Device& d, @@ -790,8 +873,15 @@ void conv_1D_gpu( bool is_idil_one = in_dilation[0] == 1; int C = in.shape(2); int O = wt.shape(0); - const int C_per_group = in.shape(2) / groups; - const int O_per_group = wt.shape(0) / groups; + // Fast path for fully separable 1D convolution + if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 && + wt_dilation[0] == 1 && padding[0] == 0 && !flip) { + depthwise_conv_1D_gpu(s, d, in, wt, out); + return; + } + + const int C_per_group = C / groups; + const int O_per_group = O / groups; // Direct to implicit gemm conv if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) && diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index e169ade71..fdec515e2 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float); instantiate_depthconv2d(float16, half); instantiate_depthconv2d(bfloat16, bfloat16_t); +template +[[kernel]] void depthwise_conv_1d( + const device T* in [[buffer(0)]], + const device T* w [[buffer(1)]], + device T* out [[buffer(2)]], + constant const IdxT strides[3], + constant const int& kernel_size, + uint3 tid [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + out += (tid.z * static_cast(grid_dim.y) + tid.y) * grid_dim.x + tid.x; + in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2]; + w += tid.x * kernel_size; + + float acc = 0.0; + for (int i = 0; i < kernel_size; ++i) { + acc += static_cast(in[0]) * w[i]; + in += strides[1]; + } + *out = static_cast(acc); +} + +#define instantiate_depthconv1d(iname, itype) \ + instantiate_kernel( \ + "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \ + instantiate_kernel( \ + "depthwise_conv_1d_" #iname "_large", \ + depthwise_conv_1d, \ + itype, \ + int64_t) + +instantiate_depthconv1d(float32, float); +instantiate_depthconv1d(float16, half); +instantiate_depthconv1d(bfloat16, bfloat16_t); + /////////////////////////////////////////////////////////////////////////////// /// Winograd kernels ///////////////////////////////////////////////////////////////////////////////