Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-22 02:58:16 +08:00)
Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
Committed by: GitHub
Parent: 57fe918cf8
Commit: b3916cbf2b
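To make the renamed arguments concrete: in the kernels changed below, each 32-bit word of w packs 32 / bits quantized elements, and every group_size consecutive elements share one scale and one bias. The following is a minimal CPU-side sketch of that scheme, not MLX code; the function name dequantize_row and the std::vector containers are assumptions for illustration only.

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: unpack one quantized row. Each 32-bit word holds 32 / bits
// elements; element `el` uses scales[el / group_size] and biases[el / group_size].
template <typename T>
std::vector<T> dequantize_row(
    const std::vector<uint32_t>& w,
    const std::vector<T>& scales,
    const std::vector<T>& biases,
    int group_size,
    int bits) {
  const uint32_t bitmask = (1u << bits) - 1;
  const int el_per_int = 32 / bits;
  std::vector<T> out;
  out.reserve(w.size() * el_per_int);
  for (std::size_t i = 0; i < w.size(); ++i) {
    uint32_t wi = w[i];
    for (int t = 0; t < el_per_int; ++t) {
      const std::size_t el = i * el_per_int + t;   // index of the dequantized element
      const std::size_t g = el / group_size;       // scale/bias group it falls in
      out.push_back(scales[g] * static_cast<T>(wi & bitmask) + biases[g]);
      wi >>= bits;
    }
  }
  return out;
}

The kernels below implement the same arithmetic (mask, scale, bias, shift), only tiled across threadgroups and simdgroups.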
@@ -14,7 +14,7 @@ using namespace metal;
 
 MLX_MTL_CONST int SIMD_SIZE = 32;
 
-template <typename T, const int BM, const int BN, const int groups, const int width>
+template <typename T, const int BM, const int BN, const int group_size, const int bits>
 [[kernel]] void qmv(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -30,10 +30,10 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
 
   static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
 
-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int el_per_thread = 32 / width;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int el_per_thread = 32 / bits;
   constexpr int colgroup = BN * el_per_thread;
-  constexpr int groups_per_block = colgroup / groups;
+  constexpr int groups_per_block = colgroup / group_size;
   constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
 
   threadgroup T scales_block[BM * groups_per_block];
@@ -48,7 +48,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
 
   // Adjust positions
   const int in_vec_size_w = in_vec_size / el_per_thread;
-  const int in_vec_size_g = in_vec_size / groups;
+  const int in_vec_size_g = in_vec_size / group_size;
   int out_row = tid.y * BM + simd_gid;
   w += out_row * in_vec_size_w;
   scales += out_row * in_vec_size_g;
@@ -66,11 +66,11 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
     if (simd_lid == 0) {
       #pragma clang loop unroll(full)
       for (int j=0; j<groups_per_block; j++) {
-        scales_block[simd_gid * groups_per_block + j] = scales[i / groups + j];
+        scales_block[simd_gid * groups_per_block + j] = scales[i / group_size + j];
       }
       #pragma clang loop unroll(full)
       for (int j=0; j<groups_per_block; j++) {
-        biases_block[simd_gid * groups_per_block + j] = biases[i / groups + j];
+        biases_block[simd_gid * groups_per_block + j] = biases[i / group_size + j];
       }
     }
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -80,8 +80,8 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
     for (int j=0; j<el_per_thread; j++) {
       x_thread[j] = x_block[simd_lid*el_per_thread + j];
     }
-    scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
-    bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
+    scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
+    bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
 
     // Load the matrix elements
     w_local = w[i / el_per_thread + simd_lid];
@@ -90,7 +90,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
     #pragma clang loop unroll(full)
     for (int k=0; k<el_per_thread; k++) {
       result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
-      w_local >>= width;
+      w_local >>= bits;
     }
   }
 
@@ -104,7 +104,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
 }
 
 
-template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
+template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
 [[kernel]] void qmm_t(
     const device T* x [[buffer(0)]],
     const device uint32_t* w [[buffer(1)]],
@@ -126,10 +126,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
 
   constexpr int WM = 2;
   constexpr int WN = 2;
-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int el_per_int = 32 / width;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int el_per_int = 32 / bits;
   constexpr int ints_per_block = BK / el_per_int;
-  constexpr int groups_per_block = (BK / groups > 0) ? (BK / groups) : 1;
+  constexpr int groups_per_block = (BK / group_size > 0) ? (BK / group_size) : 1;
   constexpr int groups_per_simd = BN / (WM * WN);
   constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
 
@@ -145,7 +145,7 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
 
   // Set the block
   const int K_w = K / el_per_int;
-  const int K_g = K / groups;
+  const int K_g = K / group_size;
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
   x += y_row * K;
@@ -172,8 +172,8 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
     if (simd_lid == 0) {
       threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
       threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
-      const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / groups;
-      const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / groups;
+      const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / group_size;
+      const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / group_size;
       #pragma clang loop unroll(full)
       for (int gs=0; gs<groups_per_simd; gs++) {
         #pragma clang loop unroll(full)
@@ -199,13 +199,13 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
       threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
 
       uint32_t wi = *w_local;
-      T scale = scales_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
-      T bias = biases_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
+      T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+      T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
 
       #pragma clang loop unroll(full)
       for (int t=0; t<el_per_int; t++) {
         Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-        wi >>= width;
+        wi >>= bits;
       }
     }
   }
@@ -231,9 +231,9 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
|
||||
}
|
||||
|
||||
|
||||
#define instantiate_qmv(name, itype, groups, width) \
|
||||
template [[host_name("qmv_n_" #name "_groups_" #groups "_width_" #width)]] \
|
||||
[[kernel]] void qmv<itype, 32, 32, groups, width>( \
|
||||
#define instantiate_qmv(name, itype, group_size, bits) \
|
||||
template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
|
||||
const device uint32_t* w [[buffer(0)]], \
|
||||
const device itype* scales [[buffer(1)]], \
|
||||
const device itype* biases [[buffer(2)]], \
|
||||
@@ -246,10 +246,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
     uint simd_gid [[simdgroup_index_in_threadgroup]], \
     uint simd_lid [[thread_index_in_simdgroup]]);
 
-#define instantiate_qmv_types(groups, width) \
-  instantiate_qmv(float32, float, groups, width) \
-  instantiate_qmv(float16, half, groups, width) \
-  instantiate_qmv(bfloat16, bfloat16_t, groups, width)
+#define instantiate_qmv_types(group_size, bits) \
+  instantiate_qmv(float32, float, group_size, bits) \
+  instantiate_qmv(float16, half, group_size, bits) \
+  instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)
 
 instantiate_qmv_types(128, 2)
 instantiate_qmv_types(128, 4)
@@ -258,9 +258,9 @@ instantiate_qmv_types( 64, 2)
 instantiate_qmv_types( 64, 4)
 instantiate_qmv_types( 64, 8)
 
-#define instantiate_qmm_t(name, itype, groups, width) \
-  template [[host_name("qmm_t_" #name "_groups_" #groups "_width_" #width)]] \
-  [[kernel]] void qmm_t<itype, 32, 64, 32, groups, width>( \
+#define instantiate_qmm_t(name, itype, group_size, bits) \
+  template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
+  [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
    const device itype* x [[buffer(0)]], \
    const device uint32_t* w [[buffer(1)]], \
    const device itype* scales [[buffer(2)]], \
@@ -274,10 +274,10 @@ instantiate_qmv_types( 64, 8)
     uint simd_gid [[simdgroup_index_in_threadgroup]], \
     uint simd_lid [[thread_index_in_simdgroup]]);
 
-#define instantiate_qmm_t_types(groups, width) \
-  instantiate_qmm_t(float32, float, groups, width) \
-  instantiate_qmm_t(float16, half, groups, width) \
-  instantiate_qmm_t(bfloat16, bfloat16_t, groups, width)
+#define instantiate_qmm_t_types(group_size, bits) \
+  instantiate_qmm_t(float32, float, group_size, bits) \
+  instantiate_qmm_t(float16, half, group_size, bits) \
+  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)
 
 instantiate_qmm_t_types(128, 2)
 instantiate_qmm_t_types(128, 4)
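For reference, a sketch (inferred from the macro bodies above, not compiler output) of the host names one of the renamed instantiation lines now produces; the parameter lists are elided:

// Illustrative expansion of instantiate_qmv_types(64, 4) under the new naming:
//   template [[host_name("qmv_n_float32_gs_64_b_4")]]  [[kernel]] void qmv<float, 32, 32, 64, 4>(...);
//   template [[host_name("qmv_n_float16_gs_64_b_4")]]  [[kernel]] void qmv<half, 32, 32, 64, 4>(...);
//   template [[host_name("qmv_n_bfloat16_gs_64_b_4")]] [[kernel]] void qmv<bfloat16_t, 32, 32, 64, 4>(...);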
@@ -58,7 +58,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (B == 1) {
     std::ostringstream kname;
     kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
-          << "_groups_" << groups_ << "_width_" << width_;
+          << "_gs_" << group_size_ << "_b_" << bits_;
 
     // Encode and dispatch kernel
     auto compute_encoder = d.get_command_encoder(s.index);
@@ -87,7 +87,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   else {
     std::ostringstream kname;
     kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
-          << "_groups_" << groups_ << "_width_" << width_;
+          << "_gs_" << group_size_ << "_b_" << bits_;
 
     // Encode and dispatch kernel
     auto compute_encoder = d.get_command_encoder(s.index);
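On the host side the kernel is looked up by a string built with the same new suffixes. Below is a self-contained sketch of that construction; w_transposed, the "float16" stand-in for type_to_name(out), and the group_size_/bits_ values are assumptions chosen so the result lines up with the qmv host names instantiated above.

#include <iostream>
#include <sstream>
#include <string>

int main() {
  // Stand-ins for values the real code gets from the op and the output array.
  bool w_transposed = true;           // assumed
  std::string type_name = "float16";  // what type_to_name(out) might return
  int group_size_ = 64;               // the new default group size
  int bits_ = 4;

  std::ostringstream kname;
  kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_name
        << "_gs_" << group_size_ << "_b_" << bits_;

  // Prints "qmv_n_float16_gs_64_b_4", matching the [[host_name(...)]]
  // generated by instantiate_qmv(float16, half, 64, 4) above.
  std::cout << kname.str() << std::endl;
  return 0;
}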