mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 16:48:10 +08:00
faster depthwise 1D conv (#2567)
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
|
||||
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
||||
|
||||
// Prepare unfolding kernel
|
||||
std::ostringstream kname;
|
||||
kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
|
||||
<< N;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
|
||||
/* const int swizzle_log = */ swizzle_log};
|
||||
|
||||
// Determine kernel
|
||||
std::ostringstream kname;
|
||||
kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
|
||||
<< bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
|
||||
<< (n_channel_specialization ? std::to_string(n_channel_specialization)
|
||||
: "l")
|
||||
<< "_filter_" << (small_filter ? 's' : 'l');
|
||||
std::string kname;
|
||||
kname.reserve(64);
|
||||
concatenate(
|
||||
kname,
|
||||
"implicit_gemm_conv_2d_",
|
||||
type_to_name(out),
|
||||
"_bm",
|
||||
bm,
|
||||
"_bn",
|
||||
bn,
|
||||
"_bk",
|
||||
bk,
|
||||
"_wm",
|
||||
wm,
|
||||
"_wn",
|
||||
wn,
|
||||
"_channel_",
|
||||
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
|
||||
"_filter_",
|
||||
small_filter ? 's' : 'l');
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = get_steel_conv_kernel(
|
||||
d,
|
||||
kname.str(),
|
||||
kname,
|
||||
out,
|
||||
bm,
|
||||
bn,
|
||||
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
|
||||
{
|
||||
int bc = 32;
|
||||
int bo = 4;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname,
|
||||
"winograd_conv_2d_weight_transform_",
|
||||
type_to_name(out),
|
||||
"_bc",
|
||||
bc);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(wt, 0);
|
||||
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
|
||||
int bc = 32;
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
|
||||
<< bc;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname,
|
||||
"winograd_conv_2d_input_transform_",
|
||||
type_to_name(out),
|
||||
"_bc",
|
||||
bc);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in_padded, 0);
|
||||
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
|
||||
int bc = 32;
|
||||
int wm = 2;
|
||||
int wn = 2;
|
||||
std::ostringstream kname;
|
||||
kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
|
||||
<< bc;
|
||||
std::string kname;
|
||||
kname.reserve(32);
|
||||
concatenate(
|
||||
kname,
|
||||
"winograd_conv_2d_output_transform_",
|
||||
type_to_name(out),
|
||||
"_bo",
|
||||
bc);
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
auto kernel = d.get_kernel(kname);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(out_wg, 0);
|
||||
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
std::ostringstream kname;
|
||||
kname << "depthwise_conv_2d_" << type_to_name(out);
|
||||
std::string base_name = kname.str();
|
||||
std::string base_name;
|
||||
base_name.reserve(32);
|
||||
concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
|
||||
|
||||
const int N = conv_params.N;
|
||||
const int ker_h = conv_params.wS[0];
|
||||
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_ker_h_" << ker_h
|
||||
<< "_ker_w_" << ker_w
|
||||
<< "_str_h_" << str_h
|
||||
<< "_str_w_" << str_w
|
||||
<< "_tgp_h_" << th
|
||||
<< "_tgp_w_" << tw
|
||||
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
std::string hash_name;
|
||||
hash_name.reserve(64);
|
||||
concatenate(
|
||||
hash_name,
|
||||
base_name,
|
||||
"_ker_h_", ker_h,
|
||||
"_ker_w_", ker_w,
|
||||
"_str_h_", str_h,
|
||||
"_str_w_", str_w,
|
||||
"_tgp_h_", th,
|
||||
"_tgp_w_", tw,
|
||||
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
||||
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void depthwise_conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
array wt,
|
||||
array out) {
|
||||
bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
|
||||
std::string base_name;
|
||||
base_name.reserve(32);
|
||||
concatenate(
|
||||
base_name,
|
||||
"depthwise_conv_1d_",
|
||||
large ? "_large" : "",
|
||||
type_to_name(out));
|
||||
|
||||
if (!wt.flags().row_contiguous) {
|
||||
wt = contiguous_copy_gpu(wt, s);
|
||||
d.add_temporary(wt, s.index);
|
||||
}
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto B = in.shape(0);
|
||||
auto Tout = out.shape(1);
|
||||
auto D = in.shape(2);
|
||||
auto K = wt.shape(1);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
if (large) {
|
||||
int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
|
||||
compute_encoder.set_bytes(strides, 3, 3);
|
||||
|
||||
} else {
|
||||
int strides[3] = {
|
||||
static_cast<int>(in.strides(0)),
|
||||
static_cast<int>(in.strides(1)),
|
||||
static_cast<int>(in.strides(2))};
|
||||
compute_encoder.set_bytes(strides, 3, 3);
|
||||
}
|
||||
|
||||
compute_encoder.set_bytes(K, 4);
|
||||
auto group_dims = get_block_dims(D, Tout, B);
|
||||
MTL::Size grid_dims = MTL::Size(D, Tout, B);
|
||||
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void conv_1D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -790,8 +873,15 @@ void conv_1D_gpu(
|
||||
bool is_idil_one = in_dilation[0] == 1;
|
||||
int C = in.shape(2);
|
||||
int O = wt.shape(0);
|
||||
const int C_per_group = in.shape(2) / groups;
|
||||
const int O_per_group = wt.shape(0) / groups;
|
||||
// Fast path for fully separable 1D convolution
|
||||
if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
|
||||
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
|
||||
depthwise_conv_1D_gpu(s, d, in, wt, out);
|
||||
return;
|
||||
}
|
||||
|
||||
const int C_per_group = C / groups;
|
||||
const int O_per_group = O / groups;
|
||||
|
||||
// Direct to implicit gemm conv
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
|
@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
|
||||
instantiate_depthconv2d(float16, half);
|
||||
instantiate_depthconv2d(bfloat16, bfloat16_t);
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
[[kernel]] void depthwise_conv_1d(
|
||||
const device T* in [[buffer(0)]],
|
||||
const device T* w [[buffer(1)]],
|
||||
device T* out [[buffer(2)]],
|
||||
constant const IdxT strides[3],
|
||||
constant const int& kernel_size,
|
||||
uint3 tid [[thread_position_in_grid]],
|
||||
uint3 grid_dim [[threads_per_grid]]) {
|
||||
out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
|
||||
in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
|
||||
w += tid.x * kernel_size;
|
||||
|
||||
float acc = 0.0;
|
||||
for (int i = 0; i < kernel_size; ++i) {
|
||||
acc += static_cast<float>(in[0]) * w[i];
|
||||
in += strides[1];
|
||||
}
|
||||
*out = static_cast<T>(acc);
|
||||
}
|
||||
|
||||
#define instantiate_depthconv1d(iname, itype) \
|
||||
instantiate_kernel( \
|
||||
"depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \
|
||||
instantiate_kernel( \
|
||||
"depthwise_conv_1d_" #iname "_large", \
|
||||
depthwise_conv_1d, \
|
||||
itype, \
|
||||
int64_t)
|
||||
|
||||
instantiate_depthconv1d(float32, float);
|
||||
instantiate_depthconv1d(float16, half);
|
||||
instantiate_depthconv1d(bfloat16, bfloat16_t);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
/// Winograd kernels
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
Reference in New Issue
Block a user