diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index e4f821b69..149f49e35 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2162,7 +2162,7 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
   constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
-  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
+  constexpr int packs_per_thread = 1;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
   constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
@@ -2179,14 +2179,16 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   thread U result[results_per_simdgroup] = {0};
 
   // Adjust positions
-  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_w =
+      in_vec_size * results_per_simdgroup * bytes_per_pack / pack_factor;
   const int in_vec_size_g =
       in_vec_size * results_per_simdgroup * 2 / group_size;
-  const int scales_row = tid.x * num_simdgroups + simd_gid;
-  const int out_row = scales_row * results_per_simdgroup;
+  const int w_row = tid.x * num_simdgroups + simd_gid;
+  const int out_row = w_row * results_per_simdgroup;
 
-  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
-  scales += scales_row * in_vec_size_g +
+  ws += w_row * in_vec_size_w +
+      simd_lid * results_per_simdgroup * packs_per_thread * bytes_per_pack;
+  scales += w_row * in_vec_size_g +
       results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
   x += tid.y * in_vec_size + simd_lid * values_per_thread;
   y += tid.y * out_vec_size + out_row;
@@ -2194,18 +2196,16 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
   for (int k = 0; k < in_vec_size; k += block_size) {
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
-    U sb[2 * results_per_simdgroup];
-    for (int i = 0; i < 2 * results_per_simdgroup; i++) {
-      sb[i] = scales[i];
-    }
-
     for (int row = 0; row < results_per_simdgroup; row++) {
-      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
       result[row] += qdot<U, values_per_thread, bits>(
-          wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
+          ws + row * bytes_per_pack,
+          x_thread,
+          scales[2 * row + 0],
+          scales[2 * row + 1],
+          sum);
     }
 
-    ws += block_size * bytes_per_pack / pack_factor;
+    ws += results_per_simdgroup * block_size * bytes_per_pack / pack_factor;
     scales += block_size * 2 * results_per_simdgroup / group_size;
     x += block_size;
   }
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 3fc8fe70c..3577c33de 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -131,6 +131,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
   int scales_dims = scales.shape(-1) * group_size;
   if (type == QuantizationType::AffinePacked) {
     scales_dims /= 8;
+    weight_dims /= 4;
   }
 
   if (weight_dims != scales_dims) {
@@ -147,8 +148,12 @@
   int x_inner_dims = x.shape(-1);
 
   // Calculate the expanded w's dims
-  int w_inner_dims = (transpose) ? weight_dims : w.shape(-2);
-  int w_outer_dims = (transpose) ? w.shape(-2) : weight_dims;
+  int weight_dims_other = w.shape(-2);
+  if (type == QuantizationType::AffinePacked) {
+    weight_dims_other *= 4;
+  }
+  int w_inner_dims = (transpose) ? weight_dims : weight_dims_other;
+  int w_outer_dims = (transpose) ? weight_dims_other : weight_dims;
 
   if (w_inner_dims != x_inner_dims) {
     std::ostringstream msg;
@@ -3778,20 +3783,25 @@ std::tuple<array, array, std::optional<array>> quantize(
     int bits /* = 4 */,
     QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
-  auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
+  switch (type) {
+    case QuantizationType::Affine:
+      return fast::affine_quantize(w, group_size, bits, s);
+    case QuantizationType::AffinePacked: {
+      auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
 
-  // Pack scales and biases
-  if (type == QuantizationType::AffinePacked) {
-    scales = unflatten(scales, -2, {-1, 4, 1}, s);
-    biases = unflatten(biases, -2, {-1, 4, 1}, s);
-    scales = concatenate({scales, biases}, -2, s);
-    scales = flatten(scales, -3, -2, s);
-    scales = moveaxis(scales, -2, -1, s);
-    scales = flatten(scales, -2, -1, s);
+      scales = unflatten(scales, -2, {-1, 4, 1}, s);
+      biases = unflatten(biases, -2, {-1, 4, 1}, s);
+      scales = concatenate({scales, biases}, -2, s);
+      scales = flatten(scales, -3, -2, s);
+      scales = moveaxis(scales, -2, -1, s);
+      scales = flatten(scales, -2, -1, s);
 
-    return std::make_tuple(wq, scales, std::nullopt);
-  } else {
-    return std::make_tuple(wq, scales, biases);
+      wq = unflatten(wq, -2, {-1, 4}, s);
+      wq = moveaxis(wq, -2, -1, s);
+      wq = flatten(wq, -2, -1, s);
+
+      return std::make_tuple(wq, scales, std::nullopt);
+    }
   }
 }
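
Note on the new AffinePacked weight layout (an illustration, not part of the patch): after this change wq stores the pack words of every 4 consecutive output rows interleaved word-by-word, mirroring the existing scales/biases interleaving. That is what the unflatten/moveaxis/flatten sequence in quantize() computes, and it is why the kernel now reads a group's row at ws + row * bytes_per_pack. A minimal standalone sketch of the equivalent shuffle, assuming row-major uint32_t pack words and a hypothetical helper name:

#include <cstdint>
#include <vector>

// Hypothetical helper (not MLX API): interleave the pack words of every 4
// consecutive rows, equivalent to wq = unflatten(wq, -2, {-1, 4});
// wq = moveaxis(wq, -2, -1); wq = flatten(wq, -2, -1).
// In:  row-major (n_rows x n_words), with n_rows % 4 == 0.
// Out: (n_rows / 4) x (n_words * 4), where word k of original row 4 * g + r
// lands at column k * 4 + r of packed row g.
std::vector<uint32_t> interleave_rows_of_4(
    const std::vector<uint32_t>& wq, int n_rows, int n_words) {
  std::vector<uint32_t> out(wq.size());
  for (int g = 0; g < n_rows / 4; ++g) {
    for (int r = 0; r < 4; ++r) {
      for (int k = 0; k < n_words; ++k) {
        out[(g * n_words + k) * 4 + r] = wq[(4 * g + r) * n_words + k];
      }
    }
  }
  return out;
}

With this layout the four output rows handled by one simdgroup sit in adjacent words, so each thread's loads for all four results are contiguous rather than strided by a full row (the old row * in_vec_size_w offset).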