change wino dispatch conditoin (#1534)

This commit is contained in:
Awni Hannun 2024-10-28 11:13:44 -07:00 committed by GitHub
parent d3cd26820e
commit 015c247393
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -752,10 +752,6 @@ void conv_2D_gpu(
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
bool inp_large = (conv_params.in_strides[0] >= 1ul << 18);
bool channels_large = (conv_params.C + conv_params.O) >= 512;
bool channels_med = (conv_params.C + conv_params.O) >= 256;
if (groups > 1) {
const int C_per_group = conv_params.C / groups;
const int O_per_group = conv_params.O / groups;
@ -769,10 +765,13 @@ void conv_2D_gpu(
}
// Direct to winograd conv
bool inp_large =
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
bool channels_large = (conv_params.C + conv_params.O) >= 256;
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 &&
(channels_large || (channels_med && inp_large))) {
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
channels_large) {
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
}