mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 10:48:09 +08:00
Quantize with groups of 32 (#511)
* allow quantize with group sizes of 32 * missing cpu dispatch * remove print * Fix qvm for group_size 32 --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -119,6 +119,12 @@ void _qmm_dispatch_typed(
|
||||
switch (bits) {
|
||||
case 2: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 2, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
@@ -135,6 +141,12 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
case 4: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 4, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
@@ -151,6 +163,12 @@ void _qmm_dispatch_typed(
|
||||
}
|
||||
case 8: {
|
||||
switch (group_size) {
|
||||
case 32:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
} else {
|
||||
return _qmm<T, 8, 32>(result, x, w, scales, biases, M, N, K);
|
||||
}
|
||||
case 64:
|
||||
if (transposed_w) {
|
||||
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
|
@@ -142,10 +142,11 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size / el_per_int;
|
||||
const int out_vec_size_g = out_vec_size / group_size;
|
||||
int out_col = (tid.y * BN + simd_gid) * el_per_int;
|
||||
int out_col_start = tid.y * (BN * el_per_int);
|
||||
int out_col = out_col_start + simd_gid * el_per_int;
|
||||
w += out_col / el_per_int;
|
||||
scales += out_col / group_size;
|
||||
biases += out_col / group_size;
|
||||
scales += out_col_start / group_size;
|
||||
biases += out_col_start / group_size;
|
||||
x += tid.z * in_vec_size;
|
||||
y += tid.z * out_vec_size + out_col;
|
||||
|
||||
@@ -155,26 +156,22 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
|
||||
// Loop over in_vec in blocks of colgroup
|
||||
for (int i=0; i<in_vec_size; i+=BM) {
|
||||
int offset = simd_lid + i;
|
||||
bool thread_in_bounds = offset < in_vec_size;
|
||||
int offset_lid = simd_lid + i;
|
||||
int offset_gid = simd_gid + i;
|
||||
bool thread_in_bounds = offset_lid < in_vec_size;
|
||||
bool group_in_bounds = offset_gid < in_vec_size;
|
||||
|
||||
// Load the vec to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
|
||||
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
|
||||
}
|
||||
|
||||
// Load the scales and biases to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
scales_block[simd_lid * groups_per_block + j] = scales[(i + simd_lid) * out_vec_size_g + j];
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
biases_block[simd_lid * groups_per_block + j] = biases[(i + simd_lid) * out_vec_size_g + j];
|
||||
}
|
||||
if (simd_lid < groups_per_block && group_in_bounds) {
|
||||
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
|
||||
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
@@ -184,7 +181,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
|
||||
// Load the matrix elements
|
||||
w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
|
||||
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
|
||||
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -543,6 +540,9 @@ instantiate_qmv_types(128, 8)
|
||||
instantiate_qmv_types( 64, 2)
|
||||
instantiate_qmv_types( 64, 4)
|
||||
instantiate_qmv_types( 64, 8)
|
||||
instantiate_qmv_types( 32, 2)
|
||||
instantiate_qmv_types( 32, 4)
|
||||
instantiate_qmv_types( 32, 8)
|
||||
|
||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
@@ -570,6 +570,9 @@ instantiate_qvm_types(128, 8)
|
||||
instantiate_qvm_types( 64, 2)
|
||||
instantiate_qvm_types( 64, 4)
|
||||
instantiate_qvm_types( 64, 8)
|
||||
instantiate_qvm_types( 32, 2)
|
||||
instantiate_qvm_types( 32, 4)
|
||||
instantiate_qvm_types( 32, 8)
|
||||
|
||||
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
|
||||
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
|
||||
@@ -601,6 +604,9 @@ instantiate_qmm_t_types(128, 8)
|
||||
instantiate_qmm_t_types( 64, 2)
|
||||
instantiate_qmm_t_types( 64, 4)
|
||||
instantiate_qmm_t_types( 64, 8)
|
||||
instantiate_qmm_t_types( 32, 2)
|
||||
instantiate_qmm_t_types( 32, 4)
|
||||
instantiate_qmm_t_types( 32, 8)
|
||||
|
||||
#define instantiate_qmm_n(name, itype, group_size, bits) \
|
||||
template [[host_name("qmm_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
@@ -629,3 +635,6 @@ instantiate_qmm_n_types(128, 8)
|
||||
instantiate_qmm_n_types( 64, 2)
|
||||
instantiate_qmm_n_types( 64, 4)
|
||||
instantiate_qmm_n_types( 64, 8)
|
||||
instantiate_qmm_n_types( 32, 2)
|
||||
instantiate_qmm_n_types( 32, 4)
|
||||
instantiate_qmm_n_types( 32, 8)
|
||||
|
@@ -2845,7 +2845,7 @@ std::tuple<array, array, array> quantize(
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (group_size != 64 && group_size != 128) {
|
||||
if (group_size != 32 && group_size != 64 && group_size != 128) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The requested group size " << group_size
|
||||
<< " is not supported. The supported group sizes are 64 and 128.";
|
||||
|
Reference in New Issue
Block a user