Add the 3bit packed qmm_t
commit c02e14c264
parent d75a509234
@@ -2504,12 +2504,149 @@ struct AffinePackedQuantizedBlockLoader {
         group_step_cnt++;
         if (group_step_cnt == group_steps) {
           group_step_cnt = 0;
-          scales += 8;
-          biases += 8;
+          scales += (2 * row_pack_factor);
+          biases += (2 * row_pack_factor);
         }
       } else {
-        scales += 8;
-        biases += 8;
+        scales += (2 * row_pack_factor);
+        biases += (2 * row_pack_factor);
       }
     } else {
       scales += group_stride;
+      biases += group_stride;
+    }
+  }
+};
+
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short group_size,
+    short bits>
+struct AffineScalesPackedQuantizedBlockLoader {
+  static_assert(
+      BCOLS <= group_size,
+      "The group size should be larger than the columns");
+  static_assert(
+      group_size % BCOLS == 0,
+      "The group size should be divisible by the columns");
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 6, 8}");
+
+  MLX_MTL_CONST short bytes_per_pack = (bits & (bits - 1)) ? 3 : 4;
+  MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  MLX_MTL_CONST short row_pack_factor = 2;
+  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
+  MLX_MTL_CONST short BROWS_PACKED = BROWS / row_pack_factor;
+  MLX_MTL_CONST short TOTAL_READS = BCOLS * BROWS / pack_factor;
+  MLX_MTL_CONST short n_reads =
+      (TOTAL_READS < tgp_size) ? 1 : TOTAL_READS / tgp_size;
+  MLX_MTL_CONST short group_steps = group_size / BCOLS;
+
+  const int src_ld;
+  const int tile_stride;
+  short group_step_cnt;
+  const int group_stride;
+
+  const short thread_idx;
+  const short bi;
+  const short bj;
+  const short bii;
+
+  const device uint8_t* src;
+  const device T* scales;
+  const device T* biases;
+  threadgroup T* dst;
+
+  AffineScalesPackedQuantizedBlockLoader(
+      const device uint32_t* src_,
+      const device T* scales_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(
+            reduction_dim ? BCOLS_PACKED * bytes_per_pack
+                          : BROWS * src_ld * bytes_per_pack / pack_factor),
+        group_step_cnt(0),
+        group_stride(BROWS_PACKED * 2 * src_ld / group_size),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(n_reads * thread_idx / BCOLS_PACKED),
+        bj((n_reads * thread_idx) % BCOLS_PACKED),
+        bii(bi / row_pack_factor),
+        src(((const device uint8_t*)src_) +
+            bi * src_ld * bytes_per_pack / pack_factor + bj * bytes_per_pack),
+        scales(
+            scales_ + bii * 2 * src_ld * row_pack_factor / group_size +
+            bi % row_pack_factor),
+        biases(scales + row_pack_factor),
+        dst(dst_ + bi * dst_ld + bj * pack_factor) {}
+
+  void load_unsafe() const {
+    if (bits == 2 && TOTAL_READS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          (const device uint8_t*)(src + bytes_per_pack * i),
+          scale,
+          bias,
+          dst + i * pack_factor);
+    }
+  }
+
+  void load_safe(short2 src_tile_dim) const {
+    if (TOTAL_READS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    if (reduction_dim == 1 && bii >= src_tile_dim.y) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    if (reduction_dim == 0 && bii >= src_tile_dim.x) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    for (int i = 0; i < n_reads; i++) {
+      T scale = scales[i];
+      T bias = biases[i];
+      dequantize<T, pack_factor, bits>(
+          (const device uint8_t*)(src + bytes_per_pack * i * src_ld),
+          scale,
+          bias,
+          dst + i * dst_ld);
+    }
+  }
+
+  void next() {
+    src += tile_stride;
+    if (reduction_dim == 1) {
+      if (group_steps > 1) {
+        group_step_cnt++;
+        if (group_step_cnt == group_steps) {
+          group_step_cnt = 0;
+          scales += (2 * row_pack_factor);
+          biases += (2 * row_pack_factor);
+        }
+      } else {
+        scales += (2 * row_pack_factor);
+        biases += (2 * row_pack_factor);
+      }
+    } else {
+      scales += group_stride;
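
The pack geometry in the new loader follows directly from the bit widths: power-of-2 widths fill a 4-byte word, while 3- and 6-bit values fill a 3-byte (24-bit) pack. A small host-side C++ sketch (illustrative only, not from the MLX source) checks that arithmetic:

// Illustrative host-side C++: the pack geometry used by the loader above.
constexpr short bytes_per_pack(short bits) {
  return (bits & (bits - 1)) ? 3 : 4;
}
constexpr short pack_factor(short bits) {
  return bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
}
// Each pack is exactly full: values-per-pack * bits == bytes-per-pack * 8.
static_assert(pack_factor(2) * 2 == bytes_per_pack(2) * 8, "2-bit pack");
static_assert(pack_factor(3) * 3 == bytes_per_pack(3) * 8, "3-bit pack");
static_assert(pack_factor(4) * 4 == bytes_per_pack(4) * 8, "4-bit pack");
static_assert(pack_factor(6) * 6 == bytes_per_pack(6) * 8, "6-bit pack");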
@@ -2545,10 +2682,11 @@ METAL_FUNC void affine_packed_qmm_t_impl(
 
   (void)lid;
 
+  constexpr bool power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int WM = 2;
   constexpr int WN = 2;
   constexpr int pack_factor = 32 / bits;
-  constexpr int row_pack_factor = 4;
+  constexpr int row_pack_factor = (power_of_2_bits) ? 4 : 2;
   constexpr int BK_padded = (BK + 16 / sizeof(T));
 
   // Instantiate the appropriate BlockMMA and Loader
@@ -2556,7 +2694,7 @@ METAL_FUNC void affine_packed_qmm_t_impl(
       BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
   using loader_x_t =
       mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
-  using loader_w_t = AffinePackedQuantizedBlockLoader<
+  using loader_fully_packed_t = AffinePackedQuantizedBlockLoader<
       T,
       BN,
       BK,
@@ -2565,16 +2703,30 @@ METAL_FUNC void affine_packed_qmm_t_impl(
       WM * WN * SIMD_SIZE,
       group_size,
       bits>;
+  using loader_scales_packed_t = AffineScalesPackedQuantizedBlockLoader<
+      T,
+      BN,
+      BK,
+      BK_padded,
+      1,
+      WM * WN * SIMD_SIZE,
+      group_size,
+      bits>;
+  using loader_w_t = typename ConditionalType<
+      power_of_2_bits,
+      loader_fully_packed_t,
+      loader_scales_packed_t>::type;
 
   // Set the block
-  const int K_w = K * row_pack_factor / pack_factor;
+  const int K_w =
+      (power_of_2_bits) ? K * row_pack_factor / pack_factor : K * bits / 32;
   const int K_g = K * 2 * row_pack_factor / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
   const int packed_y_col = tid.x * (BN / row_pack_factor);
 
   x += y_row * K;
-  w += packed_y_col * K_w;
+  w += (power_of_2_bits) ? packed_y_col * K_w : y_col * K_w;
   scales += packed_y_col * K_g;
   y += y_row * N + y_col;
 
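
loader_w_t is picked at compile time through ConditionalType, with the same power_of_2_bits flag also switching the weight stride K_w and the column offset applied to w. Assuming MLX's ConditionalType mirrors std::conditional (a sketch, not the library's actual definition):

// Sketch of a std::conditional-style selector, assumed equivalent to
// the ConditionalType referenced above.
template <bool condition, typename T, typename U>
struct ConditionalType {
  using type = U; // condition == false: pick the second type
};
template <typename T, typename U>
struct ConditionalType<true, T, U> {
  using type = T; // condition == true: pick the first type
};
// Here power_of_2_bits selects the fully packed loader for bits in
// {2, 4, 8} and the scales-packed loader for bits in {3, 6}.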
@@ -2692,9 +2844,6 @@ template <
           s_strides,
           tid);
     }
-    if (bits & (bits - 1)) {
-    } else {
-      affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
-          w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
-    }
+    affine_packed_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
+        w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
   }
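
Both loaders expand packs through dequantize<T, pack_factor, bits>. As a rough standalone sketch of the 3-bit case (hypothetical code; the kernel's actual bit order within a pack is not shown in this diff and is assumed little-endian here):

#include <cstdint>

// Hypothetical sketch, not the kernel's dequantize: unpack eight 3-bit
// codes from one 3-byte pack and apply w = scale * q + bias.
void dequantize_3bit_pack(
    const uint8_t* src, float scale, float bias, float* dst) {
  // Assemble the 24-bit pack into a single word (little-endian assumption).
  uint32_t pack = uint32_t(src[0]) | (uint32_t(src[1]) << 8) |
      (uint32_t(src[2]) << 16);
  for (int i = 0; i < 8; i++) {
    uint32_t q = (pack >> (3 * i)) & 0x7; // extract the i-th 3-bit code
    dst[i] = scale * float(q) + bias;
  }
}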