diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index 8328ab16a..2ca06d232 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2504,12 +2504,149 @@ struct AffinePackedQuantizedBlockLoader {
         group_step_cnt++;
         if (group_step_cnt == group_steps) {
           group_step_cnt = 0;
-          scales += 8;
-          biases += 8;
+          scales += (2 * row_pack_factor);
+          biases += (2 * row_pack_factor);
         }
       } else {
-        scales += 8;
-        biases += 8;
+        scales += (2 * row_pack_factor);
+        biases += (2 * row_pack_factor);
+      }
+    } else {
+      scales += group_stride;
+      biases += group_stride;
+    }
+  }
+};
+
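+// Weight loader for bit widths that are not powers of two (3 and 6). The
+// weight words stay in row-major order; only the scales and biases are
+// packed together in pairs of rows (row_pack_factor = 2), with the two
+// scales followed by the two biases for each group.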
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short group_size,
+    short bits>
+struct AffineScalesPackedQuantizedBlockLoader {
+  static_assert(
+      BCOLS <= group_size,
+      "The group size should be larger than the columns");
+  static_assert(
+      group_size % BCOLS == 0,
+      "The group size should be divisible by the columns");
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+
+  MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
+  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  MLX_MTL_CONST short row_pack_factor = 2;
+  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
+  MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
+  MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor;
+  MLX_MTL_CONST short n_reads =
+      (TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size;
+  MLX_MTL_CONST short group_steps = group_size / BCOLS;
+
+  const int src_ld;
+  const int tile_stride;
+  short group_step_cnt;
+  const int group_stride;
+
+  const short thread_idx;
+  const short bi;
+  const short bj;
+  const short bii;
+
+  const device uint8_t* src;
+  const device T* scales;
+  const device T* biases;
+  threadgroup T* dst;
+
+  AffineScalesPackedQuantizedBlockLoader(
+      const device uint32_t* src_,
+      const device T* scales_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(
+            reduction_dim ? BCOLS_PACKED * bytes_per_pack
+                          : BROWS * src_ld * bytes_per_pack / pack_factor),
+        group_step_cnt(0),
+        group_stride(BROWS_PACKED * 2 * src_ld / group_size),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(n_reads * thread_idx / BCOLS_PACKED),
+        bj((n_reads * thread_idx) % BCOLS_PACKED),
+        bii(bi / row_pack_factor),
+        src(((const device uint8_t*)src_) +
+            bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack),
+        scales(
+            scales_ + bii * 2 * src_ld * row_pack_factor / group_size +
+            bi % row_pack_factor),
+        biases(scales + row_pack_factor),
+        dst(dst_ + bi * dst_ld + bj * pack_factor) {}
+
+  void load_unsafe() const {
+    if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          (const device uint8_t*)(src + bytes_per_pack * i),
+          scale,
+          bias,
+          dst + i * pack_factor);
+    }
+  }
+
+  void load_safe(short2 src_tile_dim) const {
+    if (TOTAL_READS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    if (reduction_dim == 1 && bii >= src_tile_dim.y) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    if (reduction_dim == 0 && bii >= src_tile_dim.x) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    for (int i = 0; i < n_reads; i++) {
+      T scale = scales[i];
+      T bias = biases[i];
+      dequantize<T, pack_factor, bits>(
+          (const device uint8_t*)(src + bytes_per_pack * i * src_ld),
+          scale,
+          bias,
+          dst + i * dst_ld);
+    }
+  }
+
+  void next() {
+    src += tile_stride;
+    if (reduction_dim == 1) {
+      if (group_steps > 1) {
+        group_step_cnt++;
+        if (group_step_cnt == group_steps) {
+          group_step_cnt = 0;
+          scales += (2 * row_pack_factor);
+          biases += (2 * row_pack_factor);
+        }
+      } else {
+        scales += (2 * row_pack_factor);
+        biases += (2 * row_pack_factor);
       }
     } else {
       scales += group_stride;
       biases += group_stride;
@@ -2545,10 +2682,11 @@ METAL_FUNC void affine_packed_qmm_t_impl(
   (void)lid;
 
+  constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int WM = 2;
   constexpr int WN = 2;
   constexpr int pack_factor = 32 / bits;
-  constexpr int row_pack_factor = 4;
+  constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2;
   constexpr int BK_padded = (BK + 16 / sizeof(T));
 
   // Instantiate the appropriate BlockMMA and Loader
@@ -2556,7 +2694,7 @@ METAL_FUNC void affine_packed_qmm_t_impl(
       BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
   using loader_x_t =
       mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
-  using loader_w_t = AffinePackedQuantizedBlockLoader<
+  using loader_fully_packed_t = AffinePackedQuantizedBlockLoader<
       T,
       BN,
       BK,
@@ -2565,16 +2703,30 @@ METAL_FUNC void affine_packed_qmm_t_impl(
       BK_padded,
       1,
       WM * WN * SIMD_SIZE,
       group_size,
       bits>;
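+  // Choose the weight loader at compile time: power-of-2 bit widths keep
+  // the fully packed weight layout; 3- and 6-bit widths use the
+  // scales-packed loader defined above.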
+  using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader<
+      T,
+      BN,
+      BK,
+      BK_padded,
+      1,
+      WM * WN * SIMD_SIZE,
+      group_size,
+      bits>;
+  using loader_w_t = typename ConditionalType<
+      power_of_2_bits,
+      loader_fully_packed_t,
+      loader_scales_packed_t>::type;

   // Set the block
-  const int K_w = K * row_pack_factor / pack_factor;
+  const int K_w =
+      (power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32;
   const int K_g = K * 2 * row_pack_factor / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
   const int packed_y_col = tid.x * (BN / row_pack_factor);

   x += y_row * K;
-  w += packed_y_col * K_w;
+  w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w;
   scales += packed_y_col * K_g;
   y += y_row * N + y_col;
@@ -2692,9 +2844,6 @@ template <
         s_strides,
         tid);
   }
-  if (bits & (bits - 1)) {
-  } else {
-    affine_packed_qmm_t_impl<T, group_size, bits, BM, BK, BN>(
-        w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
-  }
+  affine_packed_qmm_t_impl<T, group_size, bits, BM, BK, BN>(
+      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
 }
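
Note on the packing arithmetic behind `K_w`: for 3- and 6-bit weights a pack is 24 bits (8 × 3 or 4 × 6), stored in `bytes_per_pack = 3` bytes, so a row of K values occupies `K * bits / 32` uint32 words, which is exactly the non-power-of-2 branch above. A minimal host-side C++ sketch (illustrative only; the helper names are ours, not kernel code) that mirrors the diff's constants and checks this identity for every supported bit width:

```cpp
#include <cstdio>
#include <initializer_list>

// Mirrors of the loader's compile-time constants (assumed helpers, not MLX API).
constexpr int bytes_per_pack(int bits) {
  return (bits & (bits - 1)) ? 3 : 4; // 3/6-bit values pack into 24-bit groups
}
constexpr int pack_factor(int bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; // values per pack
}

int main() {
  constexpr int K = 128; // assume K is a multiple of the pack factor
  for (int bits : {2, 3, 4, 6, 8}) {
    // Bytes per row = packs per row * bytes per pack; divide by 4 for words.
    int words = K / pack_factor(bits) * bytes_per_pack(bits) / 4;
    std::printf(
        "bits=%d pack_factor=%d bytes_per_pack=%d words_per_row=%d K*bits/32=%d\n",
        bits,
        pack_factor(bits),
        bytes_per_pack(bits),
        words,
        K * bits / 32);
  }
}
```

For power-of-2 widths the loader instead keeps four weight rows fully packed together (`row_pack_factor = 4`), which is why that branch computes `K * row_pack_factor / pack_factor` words per packed row group rather than a plain per-row word count.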