Improve qvm speed (#1140)

commit da83f899bb (parent 7e5674d8be)
mirror of https://github.com/ml-explore/mlx.git
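In brief: the qvm (quantized vector-matrix) kernel now has each thread load tn packed uint32 words at once through a small vec_w struct and accumulate tn * pack_factor partial outputs instead of pack_factor, while the 32 lanes of a simdgroup stride over the input vector and combine via simd_sum. The inner loop is split so the common case (input length divisible by the block size) runs without a per-iteration bounds check. The host-side dispatches are updated to match: 2 simdgroups per threadgroup, 64 output columns per threadgroup.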
@@ -601,14 +601,18 @@ METAL_FUNC void qvm_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int num_simdgroups = 8;
+  constexpr int num_simdgroups = 2;
   constexpr int pack_factor = 32 / bits;
+  constexpr int tn = 32 / pack_factor;
   constexpr int blocksize = SIMD_SIZE;
 
   typedef float U;
+  typedef struct {
+    uint32_t wi[tn];
+  } vec_w;
 
-  thread uint32_t w_local;
-  thread U result[pack_factor] = {0};
+  thread vec_w w_local;
+  thread U result[tn * pack_factor] = {0};
   thread U scale = 1;
   thread U bias = 0;
   thread U x_local = 0;
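The vec_w struct is the heart of the change: reading tn consecutive uint32 words as one struct turns tn narrow 4-byte loads into a single wide load (16 bytes when tn = 4), and each thread then dequantizes and accumulates tn * pack_factor outputs per input element. A rough, assumed model of that per-thread work in plain C++ (qouter's body is outside this diff, so the nibble unpacking below is illustrative and covers only the bits = 4 case):

#include <cstdint>

constexpr int bits = 4;                // assumed bit width
constexpr int pack_factor = 32 / bits; // 8 values per uint32_t
constexpr int tn = 32 / pack_factor;   // 4 packed words per thread

struct vec_w {
  uint32_t wi[tn]; // one 16-byte load instead of four 4-byte loads
};

// Accumulate x * dequant(w) into result, as one qouter<U, tn * pack_factor,
// bits> call does for a single input element x.
void qouter_model(const uint8_t* w, float x, float scale, float bias,
                  float result[tn * pack_factor]) {
  for (int k = 0; k < tn * pack_factor; k++) {
    uint8_t q = (w[k / 2] >> ((k % 2) * bits)) & 0x0F; // one 4-bit value
    result[k] += x * (scale * q + bias); // affine dequantize, accumulate
  }
}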
@@ -616,11 +620,12 @@ METAL_FUNC void qvm_impl(
   // Adjust positions
   const int out_vec_size_w = out_vec_size / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
-  w += out_col / pack_factor;
-  scales += out_col / group_size;
-  biases += out_col / group_size;
-  x += tid.y * in_vec_size;
+  int out_col =
+      tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
+  w += out_col / pack_factor + simd_lid * out_vec_size_w;
+  scales += out_col / group_size + simd_lid * out_vec_size_g;
+  biases += out_col / group_size + simd_lid * out_vec_size_g;
+  x += tid.y * in_vec_size + simd_lid;
   y += tid.y * out_vec_size + out_col;
 
   if (out_col >= out_vec_size) {
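A useful invariant hides in these constants: pack_factor * tn = pack_factor * (32 / pack_factor) = 32 for every supported bit width, so each simdgroup now owns 32 contiguous output columns and a threadgroup of num_simdgroups = 2 owns 64. The new simd_lid terms stagger the 32 lanes across consecutive rows of the input, which the simd_sum reduction at the end pays off. A quick plain C++ check of the arithmetic (bit widths 2/4/8 assumed, as in MLX's quantized kernels):

#include <cstdio>

int main() {
  const int widths[] = {2, 4, 8};
  for (int bits : widths) {
    int pack_factor = 32 / bits;
    int tn = 32 / pack_factor;
    std::printf("bits=%d: %d cols per simdgroup, %d per threadgroup\n",
                bits, pack_factor * tn, 2 * pack_factor * tn);
  }
  return 0;
}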
@@ -628,40 +633,61 @@ METAL_FUNC void qvm_impl(
   }
 
   // Loop over in_vec in blocks of blocksize
-  int i = 0;
-  for (; i + blocksize <= in_vec_size; i += blocksize) {
-    x_local = x[i + simd_lid];
-    scale = scales[(i + simd_lid) * out_vec_size_g];
-    bias = biases[(i + simd_lid) * out_vec_size_g];
-    w_local = w[(i + simd_lid) * out_vec_size_w];
-
-    qouter<U, pack_factor, bits>(
-        (thread uint8_t*)&w_local, x_local, scale, bias, result);
-  }
-  if (static_cast<int>(i + simd_lid) < in_vec_size) {
-    x_local = x[i + simd_lid];
-    scale = scales[(i + simd_lid) * out_vec_size_g];
-    bias = biases[(i + simd_lid) * out_vec_size_g];
-    w_local = w[(i + simd_lid) * out_vec_size_w];
-  } else {
-    x_local = 0;
-    scale = 0;
-    bias = 0;
-    w_local = 0;
-  }
-  qouter<U, pack_factor, bits>(
-      (thread uint8_t*)&w_local, x_local, scale, bias, result);
+  int remaining = in_vec_size % blocksize;
+  if (remaining == 0) {
+    for (int i = 0; i < in_vec_size; i += blocksize) {
+      x_local = *x;
+      scale = *scales;
+      bias = *biases;
+      w_local = *((device vec_w*)w);
+
+      qouter<U, tn * pack_factor, bits>(
+          (thread uint8_t*)&w_local, x_local, scale, bias, result);
+
+      x += blocksize;
+      scales += blocksize * out_vec_size_g;
+      biases += blocksize * out_vec_size_g;
+      w += blocksize * out_vec_size_w;
+    }
+  } else {
+    for (int i = blocksize; i < in_vec_size; i += blocksize) {
+      x_local = *x;
+      scale = *scales;
+      bias = *biases;
+      w_local = *((device vec_w*)w);
+
+      qouter<U, tn * pack_factor, bits>(
+          (thread uint8_t*)&w_local, x_local, scale, bias, result);
+
+      x += blocksize;
+      scales += blocksize * out_vec_size_g;
+      biases += blocksize * out_vec_size_g;
+      w += blocksize * out_vec_size_w;
+    }
+    if (static_cast<int>(simd_lid) < remaining) {
+      x_local = *x;
+      scale = *scales;
+      bias = *biases;
+      w_local = *((device vec_w*)w);
+    } else {
+      x_local = 0;
+      scale = 0;
+      bias = 0;
+    }
+    qouter<U, tn * pack_factor, bits>(
+        (thread uint8_t*)&w_local, x_local, scale, bias, result);
+  }
 
   // Accumulate in the simdgroup
 #pragma clang loop unroll(full)
-  for (int k = 0; k < pack_factor; k++) {
+  for (int k = 0; k < tn * pack_factor; k++) {
     result[k] = simd_sum(result[k]);
   }
 
   // Store the result
   if (simd_lid == 0) {
 #pragma clang loop unroll(full)
-    for (int k = 0; k < pack_factor; k++) {
+    for (int k = 0; k < tn * pack_factor; k++) {
       y[k] = static_cast<T>(result[k]);
     }
   }
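The restructured loop is classic tail-splitting: when in_vec_size divides evenly by blocksize (= SIMD_SIZE), the hot loop runs with pure pointer bumps and no per-iteration remainder check; otherwise the last partial block is peeled off and handled once, masking idle lanes with simd_lid < remaining. Note the masked branch no longer zeroes w_local as the old code did, presumably because x_local = 0 already makes the lane's contribution zero. The pattern itself, stripped of Metal specifics (a minimal plain C++ sketch, not the kernel):

// The even case runs with no per-iteration remainder check; the uneven
// case peels the partial block off once at the end.
float sum_blocked(const float* x, int n, int blocksize) {
  float acc = 0;
  int remaining = n % blocksize;
  if (remaining == 0) {
    for (int i = 0; i < n; i += blocksize)   // branch-free hot loop
      for (int j = 0; j < blocksize; j++)
        acc += x[i + j];
  } else {
    for (int i = 0; i < n - remaining; i += blocksize)
      for (int j = 0; j < blocksize; j++)
        acc += x[i + j];
    for (int j = 0; j < remaining; j++)      // peeled tail, handled once
      acc += x[n - remaining + j];
  }
  return acc;
}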
@@ -137,10 +137,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
-  int bo = 8;
+  int bo = 64;
   int bd = 32;
-  MTL::Size group_dims = MTL::Size(bd, bo, 1);
-  MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
+  MTL::Size group_dims = MTL::Size(bd, 2, 1);
+  MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
 
   compute_encoder.set_input_array(x, 0);
   compute_encoder.set_input_array(w, 1);
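The dispatch mirrors the kernel constants: bd = 32 lanes times 2 simdgroups per threadgroup, and bo = 64 output columns per threadgroup. Dropping the ceiling division ((O + bo - 1) / bo becomes O / bo) implies the output width O is expected to be a multiple of 64, which MLX's quantization size constraints make plausible, though that guarantee lives outside this diff. A plain C++ illustration of the geometry (sizes are assumed examples, not values from the diff):

#include <cstdio>

int main() {
  int O = 4096, B = 1;            // assumed output columns / batch
  int bo = 64, bd = 32;
  int threads_per_group = bd * 2; // 32 lanes x num_simdgroups = 2
  int groups_x = O / bo;          // exact division: assumes O % 64 == 0
  std::printf("grid %d x %d, %d threads per group\n",
              groups_x, B, threads_per_group);
  return 0;
}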
@@ -393,10 +393,10 @@ void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto kernel = d.get_kernel(kname.str());
   compute_encoder->setComputePipelineState(kernel);
 
-  int bo = 8;
+  int bo = 64;
   int bd = 32;
-  MTL::Size group_dims = MTL::Size(bd, bo, 1);
-  MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
+  MTL::Size group_dims = MTL::Size(bd, 2, 1);
+  MTL::Size grid_dims = MTL::Size(O / bo, B, N);
 
   compute_encoder.set_input_array(x, 0);
   compute_encoder.set_input_array(w, 1);
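BlockSparseQMM receives the identical change; the only difference is the grid's third dimension N, carried over from the old dispatch. The geometry sketch above applies as-is.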