diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index e631d2d7b..ec4dfad91 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -259,10 +259,10 @@ void qmv_no_parallel_m( const Stream& s) { int B = out.size() / M / N; - int bn = 128; - // int bk = 32; - MTL::Size group_dims(2, 1, 1); - MTL::Size grid_dims((N + bn - 1) / bn, 1, B); + int bn = 8; + int bk = 32; + MTL::Size group_dims(bk, 2, 1); + MTL::Size grid_dims(1, (N + bn - 1) / bn, B); std::string kname; kname.reserve(64);