Improve qvm speed (#1140)

commit da83f899bb
parent 7e5674d8be
Angelos Katharopoulos authored 2024-05-20 09:20:44 -07:00, committed by GitHub
2 changed files with 62 additions and 36 deletions


@@ -601,14 +601,18 @@ METAL_FUNC void qvm_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int num_simdgroups = 8;
+  constexpr int num_simdgroups = 2;
   constexpr int pack_factor = 32 / bits;
+  constexpr int tn = 32 / pack_factor;
   constexpr int blocksize = SIMD_SIZE;
 
   typedef float U;
+  typedef struct {
+    uint32_t wi[tn];
+  } vec_w;
 
-  thread uint32_t w_local;
-  thread U result[pack_factor] = {0};
+  thread vec_w w_local;
+  thread U result[tn * pack_factor] = {0};
   thread U scale = 1;
   thread U bias = 0;
   thread U x_local = 0;
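
A note on the new constants above: since pack_factor = 32 / bits and tn = 32 / pack_factor, each vec_w load is tn packed uint32 words, i.e. exactly 32 quantized values per lane regardless of bit width, and the per-lane result tile grows from pack_factor to tn * pack_factor = 32 outputs. A minimal sketch of that arithmetic, in plain C++ for illustration only (the struct name qvm_tile is not from the source):

#include <cstdio>

// Illustrative sketch: per-lane tile sizes implied by the kernel's constants.
// pack_factor = quantized values packed per uint32, tn = uint32 words per vec_w load.
template <int bits>
struct qvm_tile {
  static constexpr int pack_factor = 32 / bits;              // e.g. 8 for 4-bit
  static constexpr int tn = 32 / pack_factor;                // e.g. 4 words for 4-bit
  static constexpr int bytes_per_load = tn * 4;              // size of one vec_w
  static constexpr int outputs_per_lane = tn * pack_factor;  // always 32
};

int main() {
  std::printf("2-bit: tn=%d, %d B per load, %d outputs\n",
              qvm_tile<2>::tn, qvm_tile<2>::bytes_per_load, qvm_tile<2>::outputs_per_lane);
  std::printf("4-bit: tn=%d, %d B per load, %d outputs\n",
              qvm_tile<4>::tn, qvm_tile<4>::bytes_per_load, qvm_tile<4>::outputs_per_lane);
  std::printf("8-bit: tn=%d, %d B per load, %d outputs\n",
              qvm_tile<8>::tn, qvm_tile<8>::bytes_per_load, qvm_tile<8>::outputs_per_lane);
}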
@@ -616,11 +620,12 @@ METAL_FUNC void qvm_impl(
   // Adjust positions
   const int out_vec_size_w = out_vec_size / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
-  w += out_col / pack_factor;
-  scales += out_col / group_size;
-  biases += out_col / group_size;
-  x += tid.y * in_vec_size;
+  int out_col =
+      tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
+  w += out_col / pack_factor + simd_lid * out_vec_size_w;
+  scales += out_col / group_size + simd_lid * out_vec_size_g;
+  biases += out_col / group_size + simd_lid * out_vec_size_g;
+  x += tid.y * in_vec_size + simd_lid;
   y += tid.y * out_vec_size + out_col;
 
   if (out_col >= out_vec_size) {
@@ -628,40 +633,61 @@ METAL_FUNC void qvm_impl(
   }
 
   // Loop over in_vec in blocks of blocksize
-  int i = 0;
-  for (; i + blocksize <= in_vec_size; i += blocksize) {
-    x_local = x[i + simd_lid];
-    scale = scales[(i + simd_lid) * out_vec_size_g];
-    bias = biases[(i + simd_lid) * out_vec_size_g];
-    w_local = w[(i + simd_lid) * out_vec_size_w];
-
-    qouter<U, pack_factor, bits>(
-        (thread uint8_t*)&w_local, x_local, scale, bias, result);
-  }
-  if (static_cast<int>(i + simd_lid) < in_vec_size) {
-    x_local = x[i + simd_lid];
-    scale = scales[(i + simd_lid) * out_vec_size_g];
-    bias = biases[(i + simd_lid) * out_vec_size_g];
-    w_local = w[(i + simd_lid) * out_vec_size_w];
-  } else {
-    x_local = 0;
-    scale = 0;
-    bias = 0;
-    w_local = 0;
-  }
-  qouter<U, pack_factor, bits>(
-      (thread uint8_t*)&w_local, x_local, scale, bias, result);
+  int remaining = in_vec_size % blocksize;
+  if (remaining == 0) {
+    for (int i = 0; i < in_vec_size; i += blocksize) {
+      x_local = *x;
+      scale = *scales;
+      bias = *biases;
+      w_local = *((device vec_w*)w);
+
+      qouter<U, tn * pack_factor, bits>(
+          (thread uint8_t*)&w_local, x_local, scale, bias, result);
+
+      x += blocksize;
+      scales += blocksize * out_vec_size_g;
+      biases += blocksize * out_vec_size_g;
+      w += blocksize * out_vec_size_w;
+    }
+  } else {
+    for (int i = blocksize; i < in_vec_size; i += blocksize) {
+      x_local = *x;
+      scale = *scales;
+      bias = *biases;
+      w_local = *((device vec_w*)w);
+
+      qouter<U, tn * pack_factor, bits>(
+          (thread uint8_t*)&w_local, x_local, scale, bias, result);
+
+      x += blocksize;
+      scales += blocksize * out_vec_size_g;
+      biases += blocksize * out_vec_size_g;
+      w += blocksize * out_vec_size_w;
+    }
+    if (static_cast<int>(simd_lid) < remaining) {
+      x_local = *x;
+      scale = *scales;
+      bias = *biases;
+      w_local = *((device vec_w*)w);
+    } else {
+      x_local = 0;
+      scale = 0;
+      bias = 0;
+    }
+    qouter<U, tn * pack_factor, bits>(
+        (thread uint8_t*)&w_local, x_local, scale, bias, result);
+  }
 
   // Accumulate in the simdgroup
 #pragma clang loop unroll(full)
-  for (int k = 0; k < pack_factor; k++) {
+  for (int k = 0; k < tn * pack_factor; k++) {
     result[k] = simd_sum(result[k]);
   }
 
   // Store the result
   if (simd_lid == 0) {
 #pragma clang loop unroll(full)
-    for (int k = 0; k < pack_factor; k++) {
+    for (int k = 0; k < tn * pack_factor; k++) {
       y[k] = static_cast<T>(result[k]);
     }
   }
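
To summarize the restructured kernel: each simdgroup now owns a tile of tn * pack_factor = 32 consecutive output columns, lane simd_lid walks input rows simd_lid, simd_lid + 32, ..., loading one x element and the matching vec_w, scale, and bias per step, qouter accumulates the 32-wide partial products, and simd_sum reduces the partials across lanes before lane 0 writes y. A scalar CPU sketch of the computation this kernel implements, assuming MLX-style affine dequantization (q * scale + bias) and least-significant-bits-first packing; plain C++, illustrative only, not the MLX API:

#include <cstdint>
#include <vector>

// Scalar reference for y = x^T * dequant(W): W is stored row-major as packed
// uint32 words (pack_factor values per word), with one (scale, bias) pair per
// group_size consecutive columns of a row.
std::vector<float> qvm_reference(
    const std::vector<float>& x,       // [in_vec_size]
    const std::vector<uint32_t>& w,    // [in_vec_size * out_vec_size / pack_factor]
    const std::vector<float>& scales,  // [in_vec_size * out_vec_size / group_size]
    const std::vector<float>& biases,  // same shape as scales
    int in_vec_size, int out_vec_size, int bits, int group_size) {
  const int pack_factor = 32 / bits;
  const uint32_t mask = (1u << bits) - 1;
  std::vector<float> y(out_vec_size, 0.0f);
  for (int i = 0; i < in_vec_size; ++i) {
    for (int o = 0; o < out_vec_size; ++o) {
      uint32_t word = w[i * (out_vec_size / pack_factor) + o / pack_factor];
      uint32_t q = (word >> ((o % pack_factor) * bits)) & mask;
      float scale = scales[i * (out_vec_size / group_size) + o / group_size];
      float bias = biases[i * (out_vec_size / group_size) + o / group_size];
      y[o] += x[i] * (q * scale + bias);  // affine dequantization, then accumulate
    }
  }
  return y;
}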


@@ -137,10 +137,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
-    int bo = 8;
+    int bo = 64;
     int bd = 32;
-    MTL::Size group_dims = MTL::Size(bd, bo, 1);
-    MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
+    MTL::Size group_dims = MTL::Size(bd, 2, 1);
+    MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
 
     compute_encoder.set_input_array(x, 0);
     compute_encoder.set_input_array(w, 1);
@@ -393,10 +393,10 @@ void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);
 
-    int bo = 8;
+    int bo = 64;
     int bd = 32;
-    MTL::Size group_dims = MTL::Size(bd, bo, 1);
-    MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
+    MTL::Size group_dims = MTL::Size(bd, 2, 1);
+    MTL::Size grid_dims = MTL::Size(O / bo, B, N);
 
     compute_encoder.set_input_array(x, 0);
     compute_encoder.set_input_array(w, 1);
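
The dispatch change mirrors the kernel's new tile: bd = 32 threads per simdgroup and 2 simdgroups per threadgroup, with each simdgroup writing tn * pack_factor = 32 output columns, means a threadgroup now covers 64 columns, hence bo = 64 and group_dims = (32, 2, 1); the grid drops the ceil-divide, so O is expected to be divisible by bo on this path. A small sketch of that arithmetic in plain C++ (Size stands in for MTL::Size; the values of O and B are hypothetical):

#include <cstdio>

// Illustrative re-derivation of the new launch geometry.
struct Size { int x, y, z; };

int main() {
  const int bits = 4;
  const int pack_factor = 32 / bits;   // quantized values per uint32
  const int tn = 32 / pack_factor;     // uint32 words per vec_w
  const int simdgroups = 2;            // was 8 before this commit
  const int bd = 32;                   // threads per simdgroup (SIMD_SIZE)

  // Each simdgroup writes tn * pack_factor = 32 output columns, so a
  // threadgroup of 2 simdgroups covers 64 columns: the new bo.
  const int bo = simdgroups * tn * pack_factor;

  const int O = 4096, B = 1;           // hypothetical output size and batch
  Size group_dims{bd, simdgroups, 1};  // threads per threadgroup: (32, 2, 1)
  Size grid_dims{O / bo, B, 1};        // no ceil-divide: O assumed divisible by bo

  std::printf("bo=%d group=(%d,%d,%d) grid=(%d,%d,%d)\n",
              bo, group_dims.x, group_dims.y, group_dims.z,
              grid_dims.x, grid_dims.y, grid_dims.z);
}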