mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-22 11:14:32 +08:00
Add groups to 2-D convolutions (#1129)
* Added groups to 2-D convolutions. Only implemented for **some** specializations. Also fixed 1D grouped convs with different kernel strides and added more tests. * fix channels condition
This commit is contained in:
@@ -257,15 +257,19 @@ void implicit_gemm_conv_2D_gpu(
|
||||
const array& wt,
|
||||
array out,
|
||||
const MLXConvParams<2>& conv_params) {
|
||||
const int groups = conv_params.groups;
|
||||
const int C_per_group = conv_params.C / conv_params.groups;
|
||||
const int O_per_group = conv_params.O / conv_params.groups;
|
||||
|
||||
// Deduce implicit gemm size
|
||||
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
||||
int implicit_N = conv_params.O;
|
||||
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
|
||||
const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
||||
const int implicit_N = O_per_group;
|
||||
const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
|
||||
|
||||
// Determine block and warp tiles
|
||||
int wm = 2, wn = 2;
|
||||
|
||||
int bm = implicit_M >= 8192 && conv_params.C >= 64 ? 64 : 32;
|
||||
int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
|
||||
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
|
||||
int bk = 16;
|
||||
|
||||
@@ -281,15 +285,15 @@ void implicit_gemm_conv_2D_gpu(
|
||||
|
||||
// Fix small channel specialization
|
||||
int n_channel_specialization = 0;
|
||||
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
||||
int channel_k_iters = ((C_per_group + bk - 1) / bk);
|
||||
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
|
||||
|
||||
if (conv_params.C <= 2) {
|
||||
if (C_per_group <= 2) {
|
||||
gemm_k_iters = (implicit_K + bk - 1) / bk;
|
||||
n_channel_specialization = conv_params.C;
|
||||
} else if (conv_params.C <= 4) {
|
||||
n_channel_specialization = C_per_group;
|
||||
} else if (C_per_group <= 4) {
|
||||
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
|
||||
n_channel_specialization = conv_params.C;
|
||||
n_channel_specialization = C_per_group;
|
||||
}
|
||||
|
||||
bool small_filter = (!n_channel_specialization) &&
|
||||
@@ -340,7 +344,7 @@ void implicit_gemm_conv_2D_gpu(
|
||||
size_t grid_dim_x = tn * tile;
|
||||
|
||||
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, 1);
|
||||
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);
|
||||
|
||||
// Encode arrays
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
@@ -703,6 +707,7 @@ void conv_2D_gpu(
|
||||
const std::vector<int>& wt_strides,
|
||||
const std::vector<int>& wt_dilation,
|
||||
const std::vector<int>& in_dilation,
|
||||
const int groups,
|
||||
bool flip,
|
||||
std::vector<array>& copies) {
|
||||
// Make conv params
|
||||
@@ -718,12 +723,12 @@ void conv_2D_gpu(
|
||||
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
|
||||
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
|
||||
/* const size_t in_strides[NDIM + 2] = */
|
||||
{in.strides()[0], in.strides()[1], in.strides()[2], in.strides()[3]},
|
||||
{in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
|
||||
/* const size_t wt_strides[NDIM + 2] = */
|
||||
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
||||
{wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
|
||||
/* const size_t out_strides[NDIM + 2] = */
|
||||
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
||||
/* const int groups = */ 1,
|
||||
{out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
|
||||
/* const int groups = */ groups,
|
||||
/* const bool flip = */ flip,
|
||||
};
|
||||
|
||||
@@ -735,6 +740,18 @@ void conv_2D_gpu(
|
||||
bool channels_large = (conv_params.C + conv_params.O) >= 512;
|
||||
bool channels_med = (conv_params.C + conv_params.O) >= 256;
|
||||
|
||||
if (groups > 1) {
|
||||
const int C_per_group = conv_params.C / groups;
|
||||
const int O_per_group = conv_params.O / groups;
|
||||
|
||||
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
||||
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
||||
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
||||
} else {
|
||||
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
||||
}
|
||||
}
|
||||
|
||||
// Direct to winograd conv
|
||||
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
|
||||
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
||||
@@ -860,6 +877,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
kernel_strides_,
|
||||
kernel_dilation_,
|
||||
input_dilation_,
|
||||
groups_,
|
||||
flip_,
|
||||
copies);
|
||||
}
|
||||
|
@@ -109,6 +109,7 @@ template <typename T, int N>
|
||||
bool valid = n < params->N;
|
||||
|
||||
// Unroll dimensions
|
||||
int kernel_stride = 1;
|
||||
for (int i = N - 1; i >= 0; --i) {
|
||||
int os_ = (oS % params->oS[i]);
|
||||
int ws_ = (wS % params->wS[i]);
|
||||
@@ -125,7 +126,8 @@ template <typename T, int N>
|
||||
oS /= params->oS[i];
|
||||
wS /= params->wS[i];
|
||||
|
||||
out += ws_ * params->str[i];
|
||||
out += ws_ * kernel_stride;
|
||||
kernel_stride *= params->wS[i];
|
||||
}
|
||||
|
||||
if (valid) {
|
||||
|
@@ -133,9 +133,15 @@ implicit_gemm_conv_2d(
|
||||
const int c_col = tid_x * BN;
|
||||
const int K = gemm_params->K;
|
||||
const int N = gemm_params->N;
|
||||
const int C_per_group = params->C / params->groups;
|
||||
|
||||
// Groups
|
||||
A += tid.z * C_per_group;
|
||||
B += tid.z * N * K;
|
||||
C += tid.z * N;
|
||||
|
||||
B += c_col * K;
|
||||
C += c_row * N + c_col;
|
||||
C += c_row * (N * params->groups) + c_col;
|
||||
|
||||
const int2 offsets_a(0, c_row);
|
||||
const int2 offsets_b(0, c_col);
|
||||
@@ -171,7 +177,8 @@ implicit_gemm_conv_2d(
|
||||
// Store results to device memory
|
||||
short tgp_bm = min(BM, gemm_params->M - c_row);
|
||||
short tgp_bn = min(BN, gemm_params->N - c_col);
|
||||
mma_op.store_result_safe(C, N, short2(tgp_bn, tgp_bm));
|
||||
const int ldc = N * params->groups;
|
||||
mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
|
||||
}
|
||||
|
||||
#define instantiate_implicit_conv_2d( \
|
||||
|
Reference in New Issue
Block a user