Allow arbitrary first dimension in quantization kernels. (#458)

* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
This commit is contained in:
Angelos Katharopoulos
2024-01-16 00:46:21 -08:00
committed by GitHub
parent f44c132f4a
commit c15fe3e61b
6 changed files with 206 additions and 66 deletions

View File

@@ -154,10 +154,13 @@ template <typename T, const int BM, const int BN, const int group_size, const in
// Loop over in_vec in blocks of colgroup
for (int i=0; i<in_vec_size; i+=BM) {
int offset = simd_lid + i;
bool thread_in_bounds = offset < in_vec_size;
// Load the vec to shared memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (simd_gid == 0) {
x_block[simd_lid] = x[simd_lid + i];
x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
}
// Load the scales and biases to shared memory
@@ -180,7 +183,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
// Load the matrix elements
w_local = w[(i + simd_lid) * out_vec_size_w];
w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;
// Do all the work.
#pragma clang loop unroll(full)
@@ -206,7 +209,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
}
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
[[kernel]] void qmm_t(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],
@@ -257,6 +260,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
const short num_outs = min(BN, N - y_col);
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
mma_t mma_op(simd_gid, simd_lid);
@@ -292,21 +296,48 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BK / el_per_int);
int offset_col = offset % (BK / el_per_int);
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
if (!aligned_N && num_outs < BN) {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BK / el_per_int);
int offset_col = offset % (BK / el_per_int);
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
if (y_col + offset_col < N) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
} else {
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = 0;
}
}
}
} else {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BK / el_per_int);
int offset_col = offset % (BK / el_per_int);
const device uint32_t * w_local = w + offset_row * K_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
}
}
}
@@ -324,8 +355,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Store results to device memory
threadgroup_barrier(mem_flags::mem_threadgroup);
if (num_els < BM) {
mma_op.store_result_safe(y, N, short2(BN, num_els));
if (num_els < BM || num_outs < BN) {
mma_op.store_result_safe(y, N, short2(num_outs, num_els));
} else {
mma_op.store_result(y, N);
}
@@ -417,21 +448,48 @@ template <typename T, const int BM, const int BK, const int BN, const int group_
// Load the w tile
{
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
int offset_col = offset % (BN / el_per_int);
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
if (k + BK >= K) {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
int offset_col = offset % (BN / el_per_int);
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
if (y_row + offset_row < K) {
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
} else {
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = 0;
}
}
}
} else {
for (int wo=0; wo<w_els_per_thread; wo++) {
int offset = lid * w_els_per_thread + wo;
int offset_row = offset / (BN / el_per_int);
int offset_col = offset % (BN / el_per_int);
const device uint32_t * w_local = w + offset_row * N_w + offset_col;
threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
uint32_t wi = *w_local;
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
#pragma clang loop unroll(full)
for (int t=0; t<el_per_int; t++) {
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
wi >>= bits;
}
}
}
}
@@ -511,9 +569,9 @@ instantiate_qvm_types( 64, 2)
instantiate_qvm_types( 64, 4)
instantiate_qvm_types( 64, 8)
#define instantiate_qmm_t(name, itype, group_size, bits) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
[[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
[[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits, aligned_N>( \
const device itype* x [[buffer(0)]], \
const device uint32_t* w [[buffer(1)]], \
const device itype* scales [[buffer(2)]], \
@@ -528,9 +586,12 @@ instantiate_qvm_types( 64, 8)
uint simd_lid [[thread_index_in_simdgroup]]);
#define instantiate_qmm_t_types(group_size, bits) \
instantiate_qmm_t(float32, float, group_size, bits) \
instantiate_qmm_t(float16, half, group_size, bits) \
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)
instantiate_qmm_t(float32, float, group_size, bits, false) \
instantiate_qmm_t(float16, half, group_size, bits, false) \
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
instantiate_qmm_t(float32, float, group_size, bits, true) \
instantiate_qmm_t(float16, half, group_size, bits, true) \
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)
instantiate_qmm_t_types(128, 2)
instantiate_qmm_t_types(128, 4)

View File

@@ -52,7 +52,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = 32;
int bo = std::min(32, O);
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, O / bo, B);
@@ -72,7 +72,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
else {
std::ostringstream kname;
kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
<< bits_;
<< bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);
// Encode and dispatch kernel
auto compute_encoder = d.get_command_encoder(s.index);
@@ -85,7 +85,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int bn = 32;
int bk = 64;
MTL::Size group_dims = MTL::Size(32, wn, wm);
MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);
@@ -110,10 +110,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = 32;
int bo = std::min(32, O);
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B);
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
set_array_buffer(compute_encoder, x, 0);
set_array_buffer(compute_encoder, w, 1);