Mirror of https://github.com/ml-explore/mlx.git
Synced 2025-10-22 02:58:16 +08:00
Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
Committed by GitHub
Parent: f44c132f4a
Commit: c15fe3e61b
@@ -154,10 +154,13 @@ template <typename T, const int BM, const int BN, const int group_size, const in
 
   // Loop over in_vec in blocks of colgroup
   for (int i=0; i<in_vec_size; i+=BM) {
+    int offset = simd_lid + i;
+    bool thread_in_bounds = offset < in_vec_size;
+
     // Load the vec to shared memory
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (simd_gid == 0) {
-      x_block[simd_lid] = x[simd_lid + i];
+      x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
    }
 
     // Load the scales and biases to shared memory
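The pattern introduced here — compute the offset once, test it against the bound, and substitute zero for any out-of-range load — keeps the accumulation loop simple while making the final, partial block of in_vec safe for arbitrary sizes. Below is a minimal, standalone C++ sketch of the same idea (plain host code, not the Metal kernel; the name block_dot is made up for illustration):

#include <cstdio>
#include <vector>

// Illustrative only: mirrors the guarded-load idea from the qmv change.
// Out-of-range lanes read 0 instead of touching memory past the end, so the
// partial last block contributes nothing extra to the accumulator.
float block_dot(const std::vector<float>& x, const std::vector<float>& w, int BM) {
  float acc = 0.0f;
  int n = static_cast<int>(x.size());
  for (int i = 0; i < n; i += BM) {
    for (int lane = 0; lane < BM; lane++) {
      int offset = i + lane;                    // same role as simd_lid + i
      bool in_bounds = offset < n;              // same role as thread_in_bounds
      float xv = in_bounds ? x[offset] : 0.0f;  // zero-fill past the end
      float wv = in_bounds ? w[offset] : 0.0f;
      acc += xv * wv;
    }
  }
  return acc;
}

int main() {
  std::vector<float> x(37, 1.0f), w(37, 2.0f); // 37 is not a multiple of BM = 32
  std::printf("%f\n", block_dot(x, w, 32));    // prints 74.000000
  return 0;
}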
@@ -180,7 +183,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
     bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
 
     // Load the matrix elements
-    w_local = w[(i + simd_lid) * out_vec_size_w];
+    w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
 
     // Do all the work.
     #pragma clang loop unroll(full)
@@ -206,7 +209,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
 }
 
 
-template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
+template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
 [[kernel]] void qmm_t(
     const device T* x [[buffer(0)]],
     const device uint32_t* w [[buffer(1)]],
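Making aligned_N a compile-time template parameter means the common case, where the output dimension is a multiple of the tile width, compiles with no bounds checks at all; only the unaligned instantiation pays for the extra branching. A rough standalone C++ analogue of that specialization (copy_row_tile and its parameters are invented for illustration, not the kernel's actual helpers):

#include <algorithm>
#include <cstdio>
#include <vector>

// Illustrative only: a bool template parameter selects a checked or an
// unchecked tile copy at compile time, echoing the aligned_N split in qmm_t.
template <bool aligned_N, int BN = 32>
void copy_row_tile(const float* src, float* dst, int y_col, int N) {
  if (!aligned_N && std::min(BN, N - y_col) < BN) {
    // Unaligned path: columns at or past N are zero-filled.
    for (int c = 0; c < BN; c++)
      dst[c] = (y_col + c < N) ? src[y_col + c] : 0.0f;
  } else {
    // Aligned path: no per-element check is emitted at all.
    for (int c = 0; c < BN; c++)
      dst[c] = src[y_col + c];
  }
}

int main() {
  std::vector<float> row(40, 1.0f);
  float tile[32];
  copy_row_tile<true>(row.data(), tile, 0, 32);   // aligned: plain copy
  copy_row_tile<false>(row.data(), tile, 32, 40); // ragged last tile: 8 valid columns
  std::printf("%g %g\n", tile[7], tile[8]);       // prints: 1 0
  return 0;
}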
@@ -257,6 +260,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
 
   // Make the x loader and mma operation
   const short num_els = min(BM, M - y_row);
+  const short num_outs = min(BN, N - y_col);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);
 
@@ -292,21 +296,48 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
 
     // Load the w tile
     {
-      for (int wo=0; wo<w_els_per_thread; wo++) {
-        int offset = lid * w_els_per_thread + wo;
-        int offset_row = offset / (BK / el_per_int);
-        int offset_col = offset % (BK / el_per_int);
-        const device uint32_t * w_local = w + offset_row * K_w + offset_col;
-        threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
-
-        uint32_t wi = *w_local;
-        T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-        T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-
-        #pragma clang loop unroll(full)
-        for (int t=0; t<el_per_int; t++) {
-          Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-          wi >>= bits;
+      if (!aligned_N && num_outs < BN) {
+        for (int wo=0; wo<w_els_per_thread; wo++) {
+          int offset = lid * w_els_per_thread + wo;
+          int offset_row = offset / (BK / el_per_int);
+          int offset_col = offset % (BK / el_per_int);
+          const device uint32_t * w_local = w + offset_row * K_w + offset_col;
+          threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
+
+          if (y_col + offset_col < N) {
+            uint32_t wi = *w_local;
+            T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+            T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+
+            #pragma clang loop unroll(full)
+            for (int t=0; t<el_per_int; t++) {
+              Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+              wi >>= bits;
+            }
+          } else {
+            #pragma clang loop unroll(full)
+            for (int t=0; t<el_per_int; t++) {
+              Ws_local[t] = 0;
+            }
+          }
+        }
+      } else {
+        for (int wo=0; wo<w_els_per_thread; wo++) {
+          int offset = lid * w_els_per_thread + wo;
+          int offset_row = offset / (BK / el_per_int);
+          int offset_col = offset % (BK / el_per_int);
+          const device uint32_t * w_local = w + offset_row * K_w + offset_col;
+          threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
+
+          uint32_t wi = *w_local;
+          T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+          T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+
+          #pragma clang loop unroll(full)
+          for (int t=0; t<el_per_int; t++) {
+            Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+            wi >>= bits;
+          }
         }
       }
     }
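Each uint32_t packs el_per_int = 32 / bits quantized values; the inner loop masks off `bits` bits at a time and applies the group's scale and bias, and in the new unaligned path packed entries that fall past the valid range are written as zeros into Ws instead of being decoded. A self-contained C++ sketch of that unpacking step (host-side, illustrative; dequantize_word is an invented name):

#include <cstdint>
#include <cstdio>

// Illustrative only: unpack one packed uint32_t the way the kernel fills Ws,
// i.e. out[t] = scale * (wi & bitmask) + bias, shifting by `bits` each step.
template <int bits>
void dequantize_word(uint32_t wi, float scale, float bias, float* out) {
  constexpr int el_per_int = 32 / bits;
  constexpr uint32_t bitmask = (1u << bits) - 1;
  for (int t = 0; t < el_per_int; t++) {
    out[t] = scale * static_cast<float>(wi & bitmask) + bias;
    wi >>= bits;
  }
}

int main() {
  // Eight 4-bit values packed lowest-nibble first: 0, 1, 2, ..., 7
  uint32_t packed = 0x76543210u;
  float out[8];
  dequantize_word<4>(packed, /*scale=*/0.5f, /*bias=*/-1.0f, out);
  for (float v : out) std::printf("%g ", v); // -1 -0.5 0 0.5 1 1.5 2 2.5
  std::printf("\n");
  return 0;
}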
@@ -324,8 +355,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
 
   // Store results to device memory
   threadgroup_barrier(mem_flags::mem_threadgroup);
-  if (num_els < BM) {
-    mma_op.store_result_safe(y, N, short2(BN, num_els));
+  if (num_els < BM || num_outs < BN) {
+    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
   } else {
     mma_op.store_result(y, N);
   }
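With a ragged output dimension the last column of threadgroups produces a tile that only partially overlaps y, so the bounded store now clamps both extents (num_outs columns, num_els rows) rather than always writing BN columns. A rough C++ model of such a clamped tile store, assuming short2-style (columns, rows) extents as in the call above; store_tile_safe is an invented stand-in, not the kernel's mma_t method:

#include <cstdio>
#include <vector>

// Illustrative only: write an accumulator tile of size BM x BN into a row-major
// output of leading dimension N, clamped to num_els rows and num_outs columns.
template <int BM = 32, int BN = 32>
void store_tile_safe(const float (&tile)[BM][BN], float* y, int N,
                     int num_outs, int num_els) {
  for (int r = 0; r < num_els; r++)
    for (int c = 0; c < num_outs; c++)
      y[r * N + c] = tile[r][c];
}

int main() {
  float tile[32][32];
  for (auto& row : tile)
    for (auto& v : row) v = 1.0f;
  std::vector<float> y(5 * 20, 0.0f);          // a 5 x 20 output block (M=5, N=20)
  store_tile_safe(tile, y.data(), 20,
                  /*num_outs=*/20, /*num_els=*/5); // only the valid region is written
  std::printf("%g %g\n", y[0], y[5 * 20 - 1]);     // prints: 1 1
  return 0;
}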
@@ -417,21 +448,48 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
 
     // Load the w tile
     {
-      for (int wo=0; wo<w_els_per_thread; wo++) {
-        int offset = lid * w_els_per_thread + wo;
-        int offset_row = offset / (BN / el_per_int);
-        int offset_col = offset % (BN / el_per_int);
-        const device uint32_t * w_local = w + offset_row * N_w + offset_col;
-        threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
-
-        uint32_t wi = *w_local;
-        T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-        T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-
-        #pragma clang loop unroll(full)
-        for (int t=0; t<el_per_int; t++) {
-          Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-          wi >>= bits;
+      if (k + BK >= K) {
+        for (int wo=0; wo<w_els_per_thread; wo++) {
+          int offset = lid * w_els_per_thread + wo;
+          int offset_row = offset / (BN / el_per_int);
+          int offset_col = offset % (BN / el_per_int);
+          const device uint32_t * w_local = w + offset_row * N_w + offset_col;
+          threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
+
+          if (y_row + offset_row < K) {
+            uint32_t wi = *w_local;
+            T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+            T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+
+            #pragma clang loop unroll(full)
+            for (int t=0; t<el_per_int; t++) {
+              Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+              wi >>= bits;
+            }
+          } else {
+            #pragma clang loop unroll(full)
+            for (int t=0; t<el_per_int; t++) {
+              Ws_local[t] = 0;
+            }
+          }
+        }
+      } else {
+        for (int wo=0; wo<w_els_per_thread; wo++) {
+          int offset = lid * w_els_per_thread + wo;
+          int offset_row = offset / (BN / el_per_int);
+          int offset_col = offset % (BN / el_per_int);
+          const device uint32_t * w_local = w + offset_row * N_w + offset_col;
+          threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
+
+          uint32_t wi = *w_local;
+          T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+          T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+
+          #pragma clang loop unroll(full)
+          for (int t=0; t<el_per_int; t++) {
+            Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+            wi >>= bits;
+          }
         }
       }
     }
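Here only the final BK-sized block along K can run past the end, so the `if (k + BK >= K)` test is evaluated once per block and the cheap, unguarded loop is used for every earlier block. A minimal standalone C++ sketch of that "fast path for full blocks, guarded path only for the tail" structure (sum_blocks is an invented name; it models the control flow, not the kernel's arithmetic):

#include <cstdio>
#include <vector>

// Illustrative only: process full BK-sized blocks with an unchecked loop and
// fall back to a guarded, zero-padded loop only when the block straddles K,
// mirroring the `if (k + BK >= K)` split in the tile load above.
float sum_blocks(const std::vector<float>& w, int K, int BK) {
  float acc = 0.0f;
  for (int k = 0; k < K; k += BK) {
    if (k + BK >= K) {
      // Tail block: rows at or past K contribute zero.
      for (int r = 0; r < BK; r++)
        acc += (k + r < K) ? w[k + r] : 0.0f;
    } else {
      // Full block: no per-row check needed.
      for (int r = 0; r < BK; r++)
        acc += w[k + r];
    }
  }
  return acc;
}

int main() {
  std::vector<float> w(70, 1.0f);             // K = 70 is not a multiple of BK = 64
  std::printf("%f\n", sum_blocks(w, 70, 64)); // prints 70.000000
  return 0;
}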
@@ -511,9 +569,9 @@ instantiate_qvm_types( 64, 2)
 instantiate_qvm_types( 64, 4)
 instantiate_qvm_types( 64, 8)
 
-#define instantiate_qmm_t(name, itype, group_size, bits) \
-  template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
+#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
+  template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
+  [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits, aligned_N>( \
       const device itype* x [[buffer(0)]], \
       const device uint32_t* w [[buffer(1)]], \
       const device itype* scales [[buffer(2)]], \
@@ -528,9 +586,12 @@ instantiate_qvm_types( 64, 8)
       uint simd_lid [[thread_index_in_simdgroup]]);
 
 #define instantiate_qmm_t_types(group_size, bits) \
-  instantiate_qmm_t(float32, float, group_size, bits) \
-  instantiate_qmm_t(float16, half, group_size, bits) \
-  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)
+  instantiate_qmm_t(float32, float, group_size, bits, false) \
+  instantiate_qmm_t(float16, half, group_size, bits, false) \
+  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
+  instantiate_qmm_t(float32, float, group_size, bits, true) \
+  instantiate_qmm_t(float16, half, group_size, bits, true) \
+  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)
 
 instantiate_qmm_t_types(128, 2)
 instantiate_qmm_t_types(128, 4)
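Every (type, group_size, bits, aligned_N) combination now gets its own host_name, so the host side can select the aligned or unaligned specialization purely by string. A small C++ sketch of how such a name could be assembled, following the format used in eval_gpu below (the function name is made up; the "true"/"false" suffix produced by std::boolalpha matches what the macro's #aligned_N stringification emits):

#include <iostream>
#include <sstream>
#include <string>

// Illustrative only: build a kernel name such as "qmm_t_float16_gs_64_b_4_alN_true"
// when the output dimension O is a multiple of the 32-wide tile.
std::string qmm_t_kernel_name(const std::string& type_name, int group_size,
                              int bits, int O) {
  std::ostringstream kname;
  kname << "qmm_t_" << type_name << "_gs_" << group_size << "_b_" << bits
        << "_alN_" << std::boolalpha << ((O % 32) == 0);
  return kname.str();
}

int main() {
  std::cout << qmm_t_kernel_name("float16", 64, 4, 4096) << "\n"; // ..._alN_true
  std::cout << qmm_t_kernel_name("float16", 64, 4, 4100) << "\n"; // ..._alN_false
  return 0;
}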
@@ -52,7 +52,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
-    int bo = 32;
+    int bo = std::min(32, O);
     int bd = 32;
     MTL::Size group_dims = MTL::Size(bd, bo, 1);
     MTL::Size grid_dims = MTL::Size(1, O / bo, B);
@@ -72,7 +72,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   else {
     std::ostringstream kname;
     kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_;
+          << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
 
     // Encode and dispatch kernel
     auto compute_encoder = d.get_command_encoder(s.index);
@@ -85,7 +85,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int bn = 32;
     int bk = 64;
     MTL::Size group_dims = MTL::Size(32, wn, wm);
-    MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
+    MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
 
     set_array_buffer(compute_encoder, x, 0);
     set_array_buffer(compute_encoder, w, 1);
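Rounding the grid up with (O + bn - 1) / bn launches one extra column of threadgroups whenever O is not a multiple of bn; the unaligned kernel variant and the clamped store make that last partial tile safe. A quick worked check of the arithmetic (illustrative values, not from the commit):

#include <cstdio>

// Illustrative only: the ceil-division used for the grid's x dimension.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  // bn = 32 as in eval_gpu. O = 4096 is aligned; O = 4100 needs one extra tile
  // whose last 28 columns fall past O and are handled by the unaligned qmm_t path.
  std::printf("%d %d\n", ceil_div(4096, 32), ceil_div(4100, 32)); // prints: 128 129
  return 0;
}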
@@ -110,10 +110,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
-    int bo = 32;
+    int bo = std::min(32, O);
     int bd = 32;
     MTL::Size group_dims = MTL::Size(bd, bo, 1);
-    MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B);
+    MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
 
     set_array_buffer(compute_encoder, x, 0);
     set_array_buffer(compute_encoder, w, 1);