diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 1cc8a2a76..0a2c62074 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -752,10 +752,6 @@ void conv_2D_gpu( bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1; bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1; - bool inp_large = (conv_params.in_strides[0] >= 1ul << 18); - bool channels_large = (conv_params.C + conv_params.O) >= 512; - bool channels_med = (conv_params.C + conv_params.O) >= 256; - if (groups > 1) { const int C_per_group = conv_params.C / groups; const int O_per_group = conv_params.O / groups; @@ -769,10 +765,13 @@ void conv_2D_gpu( } // Direct to winograd conv + bool inp_large = + (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12; + bool channels_large = (conv_params.C + conv_params.O) >= 256; if (!flip && is_stride_one && is_kdil_one && is_idil_one && conv_params.wS[0] == 3 && conv_params.wS[1] == 3 && - conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && - (channels_large || (channels_med && inp_large))) { + conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large && + channels_large) { return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies); }