From 015c247393045dd6761d67d1c0108a1fb3e640e8 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 28 Oct 2024 11:13:44 -0700 Subject: [PATCH] change wino dispatch conditoin (#1534) --- mlx/backend/metal/conv.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 1cc8a2a76..0a2c62074 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -752,10 +752,6 @@ void conv_2D_gpu( bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; - bool inp_large = (conv_params.in_strides[0] >= 1ul << 18); - bool channels_large = (conv_params.C + conv_params.O) >= 512; - bool channels_med = (conv_params.C + conv_params.O) >= 256; - if (groups > 1) { const int C_per_group = conv_params.C / groups; const int O_per_group = conv_params.O / groups; @@ -769,10 +765,13 @@ void conv_2D_gpu( } // Direct to winograd conv + bool inp_large = + (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12; + bool channels_large = (conv_params.C + conv_params.O) >= 256; if (!flip && is_stride_one && is_kdil_one && is_idil_one && conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && - conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && - (channels_large || (channels_med && inp_large))) { + conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && + channels_large) { return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); }