mirror of https://github.com/ml-explore/mlx.git
faster depthwise 1D conv (#2567)
@@ -2,7 +2,6 @@
 #include <algorithm>
 #include <cassert>
 #include <numeric>
-#include <sstream>
 
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/metal/device.h"
@@ -39,10 +38,11 @@ void explicit_gemm_conv_ND_gpu(
   in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
-  std::ostringstream kname;
-  kname << "naive_unfold_nd_" << type_to_name(in_unfolded) << "_" << N;
+  std::string kname;
+  kname.reserve(32);
+  concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(in, 0);
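
Note on the refactor above: every kernel-name builder in this file switches from std::ostringstream to a plain std::string plus a variadic concatenate helper, which avoids the stream's formatting machinery and builds the name in a single reserved buffer. The helper itself is not part of this diff; the following is only a minimal sketch of a compatible implementation, with names and dispatch rules assumed for illustration rather than taken from the MLX sources.

    #include <string>
    #include <type_traits>

    // Append one value to dst: single characters and string-like values are
    // appended directly, other integral values go through std::to_string.
    template <typename T>
    void append_one(std::string& dst, const T& v) {
      if constexpr (std::is_same_v<std::decay_t<T>, char>) {
        dst += v; // e.g. the 's' / 'l' filter tags
      } else if constexpr (std::is_integral_v<std::decay_t<T>>) {
        dst += std::to_string(v); // block sizes, kernel dims, etc.
      } else {
        dst += v; // std::string or string literal
      }
    }

    // Left-to-right fold over all arguments, appending each to dst in place.
    template <typename... Args>
    void concatenate(std::string& dst, Args&&... args) {
      (append_one(dst, args), ...);
    }
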
@@ -117,11 +117,12 @@ void explicit_gemm_conv_group_ND_gpu(
   in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
-  std::ostringstream kname;
-  kname << "naive_unfold_transpose_nd_" << type_to_name(in_unfolded) << "_"
-        << N;
+  std::string kname;
+  kname.reserve(32);
+  concatenate(
+      kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
-  auto kernel = d.get_kernel(kname.str());
+  auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(in, 0);
@@ -252,18 +253,32 @@ void implicit_gemm_conv_2D_gpu(
       /* const int swizzle_log = */ swizzle_log};
 
   // Determine kernel
-  std::ostringstream kname;
-  kname << "implicit_gemm_conv_2d_" << type_to_name(out) << "_bm" << bm << "_bn"
-        << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_channel_"
-        << (n_channel_specialization ? std::to_string(n_channel_specialization)
-                                     : "l")
-        << "_filter_" << (small_filter ? 's' : 'l');
+  std::string kname;
+  kname.reserve(64);
+  concatenate(
+      kname,
+      "implicit_gemm_conv_2d_",
+      type_to_name(out),
+      "_bm",
+      bm,
+      "_bn",
+      bn,
+      "_bk",
+      bk,
+      "_wm",
+      wm,
+      "_wn",
+      wn,
+      "_channel_",
+      n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
+      "_filter_",
+      small_filter ? 's' : 'l');
 
   // Encode and dispatch kernel
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = get_steel_conv_kernel(
       d,
-      kname.str(),
+      kname,
       out,
       bm,
       bn,
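
For concreteness, the assembled name follows the same convention the old stream code produced. A small self-check using the concatenate sketch above; the tile parameters here are hypothetical and only chosen to illustrate the naming scheme.

    #include <cassert>
    #include <string>

    void example_implicit_gemm_name() {
      std::string kname;
      kname.reserve(64);
      // Hypothetical configuration: float16 output, 32x32x16 tiles, 2x2 warps,
      // no channel specialization, small filter.
      int bm = 32, bn = 32, bk = 16, wm = 2, wn = 2;
      int n_channel_specialization = 0;
      bool small_filter = true;
      concatenate(
          kname, "implicit_gemm_conv_2d_", "float16", "_bm", bm, "_bn", bn,
          "_bk", bk, "_wm", wm, "_wn", wn, "_channel_",
          n_channel_specialization ? std::to_string(n_channel_specialization)
                                   : "l",
          "_filter_", small_filter ? 's' : 'l');
      assert(
          kname ==
          "implicit_gemm_conv_2d_float16_bm32_bn32_bk16_wm2_wn2_channel_l_filter_s");
    }
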
@@ -559,11 +574,16 @@ void winograd_conv_2D_gpu(
   {
     int bc = 32;
     int bo = 4;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_weight_transform_",
+        type_to_name(out),
+        "_bc",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(wt, 0);
@@ -587,11 +607,16 @@ void winograd_conv_2D_gpu(
     int bc = 32;
     int wm = 2;
     int wn = 2;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_input_transform_" << type_to_name(out) << "_bc"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_input_transform_",
+        type_to_name(out),
+        "_bc",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(in_padded, 0);
@@ -634,11 +659,16 @@ void winograd_conv_2D_gpu(
     int bc = 32;
     int wm = 2;
     int wn = 2;
-    std::ostringstream kname;
-    kname << "winograd_conv_2d_output_transform_" << type_to_name(out) << "_bo"
-          << bc;
+    std::string kname;
+    kname.reserve(32);
+    concatenate(
+        kname,
+        "winograd_conv_2d_output_transform_",
+        type_to_name(out),
+        "_bo",
+        bc);
     auto& compute_encoder = d.get_command_encoder(s.index);
-    auto kernel = d.get_kernel(kname.str());
+    auto kernel = d.get_kernel(kname);
     compute_encoder.set_compute_pipeline_state(kernel);
 
     compute_encoder.set_input_array(out_wg, 0);
@@ -660,9 +690,9 @@ void depthwise_conv_2D_gpu(
     const array& wt,
     array out,
     const MLXConvParams<2>& conv_params) {
-  std::ostringstream kname;
-  kname << "depthwise_conv_2d_" << type_to_name(out);
-  std::string base_name = kname.str();
+  std::string base_name;
+  base_name.reserve(32);
+  concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
 
   const int N = conv_params.N;
   const int ker_h = conv_params.wS[0];
@@ -685,15 +715,18 @@ void depthwise_conv_2D_gpu(
   };
 
   // clang-format off
-  kname << "_ker_h_" << ker_h
-        << "_ker_w_" << ker_w
-        << "_str_h_" << str_h
-        << "_str_w_" << str_w
-        << "_tgp_h_" << th
-        << "_tgp_w_" << tw
-        << "_do_flip_" << (do_flip ? 't' : 'n'); // clang-format on
-
-  std::string hash_name = kname.str();
+  std::string hash_name;
+  hash_name.reserve(64);
+  concatenate(
+      hash_name,
+      base_name,
+      "_ker_h_", ker_h,
+      "_ker_w_", ker_w,
+      "_str_h_", str_h,
+      "_str_w_", str_w,
+      "_tgp_h_", th,
+      "_tgp_w_", tw,
+      "_do_flip_", do_flip ? 't' : 'n'); // clang-format on
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(base_name, hash_name, func_consts);
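
The depthwise 2D path keeps the two-name pattern used elsewhere in this file: base_name selects which compiled Metal function to specialize, while hash_name additionally encodes the function-constant values (kernel extent, strides, threadgroup tile, flip) and keys the cache of specialized pipelines. The following is only a conceptual sketch of that lookup, not the actual Device::get_kernel implementation.

    #include <string>
    #include <unordered_map>

    // Cache pipelines by hash_name; specialize the function named base_name
    // only on a miss, so each distinct constant configuration compiles once.
    template <typename Pipeline, typename Specialize>
    Pipeline& get_kernel_cached(
        std::unordered_map<std::string, Pipeline>& cache,
        const std::string& base_name,
        const std::string& hash_name,
        Specialize&& specialize) {
      auto it = cache.find(hash_name);
      if (it == cache.end()) {
        it = cache.emplace(hash_name, specialize(base_name)).first;
      }
      return it->second;
    }
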
@@ -774,6 +807,56 @@ void dispatch_conv_2D_gpu(
   }
 }
 
+void depthwise_conv_1D_gpu(
+    const Stream& s,
+    metal::Device& d,
+    const array& in,
+    array wt,
+    array out) {
+  bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
+  std::string base_name;
+  base_name.reserve(32);
+  concatenate(
+      base_name,
+      "depthwise_conv_1d_",
+      type_to_name(out),
+      large ? "_large" : "");
+
+  if (!wt.flags().row_contiguous) {
+    wt = contiguous_copy_gpu(wt, s);
+    d.add_temporary(wt, s.index);
+  }
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  auto kernel = d.get_kernel(base_name);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  auto B = in.shape(0);
+  auto Tout = out.shape(1);
+  auto D = in.shape(2);
+  auto K = wt.shape(1);
+
+  compute_encoder.set_input_array(in, 0);
+  compute_encoder.set_input_array(wt, 1);
+  compute_encoder.set_output_array(out, 2);
+  if (large) {
+    int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
+    compute_encoder.set_bytes(strides, 3, 3);
+
+  } else {
+    int strides[3] = {
+        static_cast<int>(in.strides(0)),
+        static_cast<int>(in.strides(1)),
+        static_cast<int>(in.strides(2))};
+    compute_encoder.set_bytes(strides, 3, 3);
+  }
+
+  compute_encoder.set_bytes(K, 4);
+  auto group_dims = get_block_dims(D, Tout, B);
+  MTL::Size grid_dims = MTL::Size(D, Tout, B);
+
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
+}
+
 void conv_1D_gpu(
     const Stream& s,
     metal::Device& d,
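
The new depthwise_conv_1D_gpu launches one thread per output element over a (D, Tout, B) grid; each thread reduces over the K taps of its own channel only, with no cross-channel accumulation. A scalar CPU reference of the computation being parallelized, assuming the channels-last (B, T, D) input layout and the (D, K) effective weight layout implied by the dispatch above (stride 1, no padding, no dilation, no flip); names are illustrative.

    #include <vector>

    // Input: B * T_in * D, weight: D * K (trailing singleton channel dropped),
    // output: B * T_out * D with T_out = T_in - K + 1.
    void depthwise_conv_1d_ref(
        const std::vector<float>& in,
        const std::vector<float>& w,
        std::vector<float>& out,
        int B, int T_in, int D, int K) {
      int T_out = T_in - K + 1;
      for (int b = 0; b < B; ++b) {
        for (int t = 0; t < T_out; ++t) {
          for (int d = 0; d < D; ++d) {
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
              // Same channel d for input and weight: purely depthwise.
              acc += in[(b * T_in + (t + k)) * D + d] * w[d * K + k];
            }
            out[(b * T_out + t) * D + d] = acc;
          }
        }
      }
    }
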
@@ -790,8 +873,15 @@ void conv_1D_gpu(
   bool is_idil_one = in_dilation[0] == 1;
   int C = in.shape(2);
   int O = wt.shape(0);
-  const int C_per_group = in.shape(2) / groups;
-  const int O_per_group = wt.shape(0) / groups;
+  // Fast path for fully separable 1D convolution
+  if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
+      wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
+    depthwise_conv_1D_gpu(s, d, in, wt, out);
+    return;
+  }
+
+  const int C_per_group = C / groups;
+  const int O_per_group = O / groups;
 
   // Direct to implicit gemm conv
   if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
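
In user terms, the fast path covers 1D convolutions where groups equals both the input and output channel count, with unit stride and dilation, no padding, and no flip. A hypothetical call that would take it, assuming the usual mlx::core::conv1d signature (stride, padding, dilation, groups); treat the exact argument order and helper calls as assumptions rather than documentation. The hunk that follows adds the corresponding Metal kernel and its instantiations.

    #include "mlx/mlx.h"

    using namespace mlx::core;

    int main() {
      int B = 8, T = 1024, C = 64, K = 5;
      auto x = random::normal({B, T, C}); // channels-last input (N, L, C)
      auto w = random::normal({C, K, 1}); // depthwise weight (O, K, C / groups)
      // groups == C_in == C_out, stride 1, no padding/dilation, no flip:
      // on the Metal backend this dispatches depthwise_conv_1D_gpu.
      auto y = conv1d(x, w, /*stride=*/1, /*padding=*/0, /*dilation=*/1, /*groups=*/C);
      eval(y);
      return 0;
    }
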
@@ -288,6 +288,40 @@ instantiate_depthconv2d(float32, float);
 instantiate_depthconv2d(float16, half);
 instantiate_depthconv2d(bfloat16, bfloat16_t);
 
+template <typename T, typename IdxT>
+[[kernel]] void depthwise_conv_1d(
+    const device T* in [[buffer(0)]],
+    const device T* w [[buffer(1)]],
+    device T* out [[buffer(2)]],
+    constant const IdxT strides[3],
+    constant const int& kernel_size,
+    uint3 tid [[thread_position_in_grid]],
+    uint3 grid_dim [[threads_per_grid]]) {
+  out += (tid.z * static_cast<IdxT>(grid_dim.y) + tid.y) * grid_dim.x + tid.x;
+  in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2];
+  w += tid.x * kernel_size;
+
+  float acc = 0.0;
+  for (int i = 0; i < kernel_size; ++i) {
+    acc += static_cast<float>(in[0]) * w[i];
+    in += strides[1];
+  }
+  *out = static_cast<T>(acc);
+}
+
+#define instantiate_depthconv1d(iname, itype)                            \
+  instantiate_kernel(                                                    \
+      "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t)    \
+  instantiate_kernel(                                                    \
+      "depthwise_conv_1d_" #iname "_large",                              \
+      depthwise_conv_1d,                                                 \
+      itype,                                                             \
+      int64_t)
+
+instantiate_depthconv1d(float32, float);
+instantiate_depthconv1d(float16, half);
+instantiate_depthconv1d(bfloat16, bfloat16_t);
+
 ///////////////////////////////////////////////////////////////////////////////
 /// Winograd kernels
 ///////////////////////////////////////////////////////////////////////////////
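
The two instantiations differ only in the index type used for the input strides: the host requests the "_large" variant, which uses int64_t offsets, whenever the input's element count or allocated size exceeds INT32_MAX, so the per-thread pointer arithmetic cannot overflow 32-bit offsets. A small sketch mirroring that selection logic from the host code above; the helper name is illustrative only.

    #include <cstdint>
    #include <string>

    // Mirror of the host-side choice between the int32_t and int64_t kernels.
    std::string pick_depthwise_1d_kernel(
        const std::string& type_name, size_t in_size, size_t in_data_size) {
      bool large = in_size > INT32_MAX || in_data_size > INT32_MAX;
      return "depthwise_conv_1d_" + type_name + (large ? "_large" : "");
    }
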