Improve qvm speed (#1140)

This commit is contained in:
Angelos Katharopoulos 2024-05-20 09:20:44 -07:00 committed by GitHub
parent 7e5674d8be
commit da83f899bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 36 deletions

View File

@ -601,14 +601,18 @@ METAL_FUNC void qvm_impl(
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int num_simdgroups = 8;
constexpr int num_simdgroups = 2;
constexpr int pack_factor = 32 / bits;
constexpr int tn = 32 / pack_factor;
constexpr int blocksize = SIMD_SIZE;
typedef float U;
typedef struct {
uint32_t wi[tn];
} vec_w;
thread uint32_t w_local;
thread U result[pack_factor] = {0};
thread vec_w w_local;
thread U result[tn * pack_factor] = {0};
thread U scale = 1;
thread U bias = 0;
thread U x_local = 0;
@ -616,11 +620,12 @@ METAL_FUNC void qvm_impl(
// Adjust positions
const int out_vec_size_w = out_vec_size / pack_factor;
const int out_vec_size_g = out_vec_size / group_size;
int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
w += out_col / pack_factor;
scales += out_col / group_size;
biases += out_col / group_size;
x += tid.y * in_vec_size;
int out_col =
tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
w += out_col / pack_factor + simd_lid * out_vec_size_w;
scales += out_col / group_size + simd_lid * out_vec_size_g;
biases += out_col / group_size + simd_lid * out_vec_size_g;
x += tid.y * in_vec_size + simd_lid;
y += tid.y * out_vec_size + out_col;
if (out_col >= out_vec_size) {
@ -628,40 +633,61 @@ METAL_FUNC void qvm_impl(
}
// Loop over in_vec in blocks of blocksize
int i = 0;
for (; i + blocksize <= in_vec_size; i += blocksize) {
x_local = x[i + simd_lid];
scale = scales[(i + simd_lid) * out_vec_size_g];
bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w];
int remaining = in_vec_size % blocksize;
if (remaining == 0) {
for (int i = 0; i < in_vec_size; i += blocksize) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
qouter<U, pack_factor, bits>(
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
}
if (static_cast<int>(i + simd_lid) < in_vec_size) {
x_local = x[i + simd_lid];
scale = scales[(i + simd_lid) * out_vec_size_g];
bias = biases[(i + simd_lid) * out_vec_size_g];
w_local = w[(i + simd_lid) * out_vec_size_w];
} else {
for (int i = blocksize; i < in_vec_size; i += blocksize) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize;
scales += blocksize * out_vec_size_g;
biases += blocksize * out_vec_size_g;
w += blocksize * out_vec_size_w;
}
if (static_cast<int>(simd_lid) < remaining) {
x_local = *x;
scale = *scales;
bias = *biases;
w_local = *((device vec_w*)w);
} else {
x_local = 0;
scale = 0;
bias = 0;
w_local = 0;
}
qouter<U, pack_factor, bits>(
qouter<U, tn * pack_factor, bits>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
}
// Accumulate in the simdgroup
#pragma clang loop unroll(full)
for (int k = 0; k < pack_factor; k++) {
for (int k = 0; k < tn * pack_factor; k++) {
result[k] = simd_sum(result[k]);
}
// Store the result
if (simd_lid == 0) {
#pragma clang loop unroll(full)
for (int k = 0; k < pack_factor; k++) {
for (int k = 0; k < tn * pack_factor; k++) {
y[k] = static_cast<T>(result[k]);
}
}

View File

@ -137,10 +137,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
@ -393,10 +393,10 @@ void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
auto kernel = d.get_kernel(kname.str());
compute_encoder->setComputePipelineState(kernel);
int bo = 8;
int bo = 64;
int bd = 32;
MTL::Size group_dims = MTL::Size(bd, bo, 1);
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
MTL::Size group_dims = MTL::Size(bd, 2, 1);
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);