mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Depthwise Conv2D optimization (#2036)
- Add new specialized kernel for small kernel (kernels size <= 7), small strides (strides <= 2) depthwise 2d convolutions - Add related tests
This commit is contained in:
@@ -712,6 +712,65 @@ void winograd_conv_2D_gpu(
|
||||
}
|
||||
}
|
||||
|
||||
void depthwise_conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
const array& in,
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
std::ostringstream kname;
|
||||
kname << "depthwise_conv_2d_" << type_to_name(out);
|
||||
std::string base_name = kname.str();
|
||||
|
||||
const int N = conv_params.N;
|
||||
const int ker_h = conv_params.wS[0];
|
||||
const int ker_w = conv_params.wS[1];
|
||||
const int str_h = conv_params.str[0];
|
||||
const int str_w = conv_params.str[1];
|
||||
const int tc = 8;
|
||||
const int tw = 8;
|
||||
const int th = 4;
|
||||
const bool do_flip = conv_params.flip;
|
||||
|
||||
metal::MTLFCList func_consts = {
|
||||
{&ker_h, MTL::DataType::DataTypeInt, 00},
|
||||
{&ker_w, MTL::DataType::DataTypeInt, 01},
|
||||
{&str_h, MTL::DataType::DataTypeInt, 10},
|
||||
{&str_w, MTL::DataType::DataTypeInt, 11},
|
||||
{&th, MTL::DataType::DataTypeInt, 100},
|
||||
{&tw, MTL::DataType::DataTypeInt, 101},
|
||||
{&do_flip, MTL::DataType::DataTypeBool, 200},
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
kname << "_ker_h_" << ker_h
|
||||
<< "_ker_w_" << ker_w
|
||||
<< "_str_h_" << str_h
|
||||
<< "_str_w_" << str_w
|
||||
<< "_tgp_h_" << th
|
||||
<< "_tgp_w_" << tw
|
||||
<< "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
|
||||
|
||||
std::string hash_name = kname.str();
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
auto kernel = d.get_kernel(base_name, "mlx", hash_name, func_consts);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_input_array(wt, 1);
|
||||
compute_encoder.set_output_array(out, 2);
|
||||
|
||||
compute_encoder.set_bytes(conv_params, 3);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(tc, tw, th);
|
||||
MTL::Size grid_dims = MTL::Size(
|
||||
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
|
||||
|
||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void conv_2D_gpu(
|
||||
const Stream& s,
|
||||
metal::Device& d,
|
||||
@@ -754,11 +813,20 @@ void conv_2D_gpu(
|
||||
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
||||
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
||||
|
||||
if (groups > 1) {
|
||||
if (is_idil_one && groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
|
||||
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
|
||||
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
|
||||
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
|
||||
conv_params.wt_strides[1] == conv_params.wS[1] &&
|
||||
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
|
||||
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
|
||||
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user