diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index 88700f645..5e70fc8fe 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -2158,7 +2158,18 @@ inline vec partial_qdot_vec(const thread U* x, vec w) { vec accum = 0; - if (bits == 4) { + if (bits == 2) { + for (int i = 0; i < 4; i++) { + auto ws = as_type>(w[i]); + for (int j = 0; j < 4; j++) { + accum[i] += + (x[4 * j + 0] * (ws[j] & 0x03) + x[4 * j + 1] * (ws[j] & 0x0c) + + x[4 * j + 2] * (ws[j] & 0x30) + x[4 * j + 3] * (ws[j] & 0xc0)); + } + } + } + + else if (bits == 4) { for (int i = 0; i < 4; i++) { auto ws = as_type>(w[i]); for (int j = 0; j < 2; j++) { @@ -2193,7 +2204,7 @@ METAL_FUNC void affine_packed_qmv_fast_impl( uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int packs_per_thread = 2; + constexpr int packs_per_thread = (bits == 2) ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;