diff --git a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h index b0b98d21a..2312e1ca6 100644 --- a/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +++ b/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h @@ -83,7 +83,7 @@ struct Conv2DInputBlockLoaderSmallChannels { const constant MLXConvParams<2>* params; const constant ImplicitGemmConv2DParams* gemm_params; - short weight_hw; + int weight_hw; const device T* src[n_rows];