mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Improve qvm speed (#1140)
This commit is contained in:
parent
7e5674d8be
commit
da83f899bb
@ -601,14 +601,18 @@ METAL_FUNC void qvm_impl(
|
|||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
constexpr int num_simdgroups = 8;
|
constexpr int num_simdgroups = 2;
|
||||||
constexpr int pack_factor = 32 / bits;
|
constexpr int pack_factor = 32 / bits;
|
||||||
|
constexpr int tn = 32 / pack_factor;
|
||||||
constexpr int blocksize = SIMD_SIZE;
|
constexpr int blocksize = SIMD_SIZE;
|
||||||
|
|
||||||
typedef float U;
|
typedef float U;
|
||||||
|
typedef struct {
|
||||||
|
uint32_t wi[tn];
|
||||||
|
} vec_w;
|
||||||
|
|
||||||
thread uint32_t w_local;
|
thread vec_w w_local;
|
||||||
thread U result[pack_factor] = {0};
|
thread U result[tn * pack_factor] = {0};
|
||||||
thread U scale = 1;
|
thread U scale = 1;
|
||||||
thread U bias = 0;
|
thread U bias = 0;
|
||||||
thread U x_local = 0;
|
thread U x_local = 0;
|
||||||
@ -616,11 +620,12 @@ METAL_FUNC void qvm_impl(
|
|||||||
// Adjust positions
|
// Adjust positions
|
||||||
const int out_vec_size_w = out_vec_size / pack_factor;
|
const int out_vec_size_w = out_vec_size / pack_factor;
|
||||||
const int out_vec_size_g = out_vec_size / group_size;
|
const int out_vec_size_g = out_vec_size / group_size;
|
||||||
int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
|
int out_col =
|
||||||
w += out_col / pack_factor;
|
tid.x * (num_simdgroups * pack_factor * tn) + simd_gid * pack_factor * tn;
|
||||||
scales += out_col / group_size;
|
w += out_col / pack_factor + simd_lid * out_vec_size_w;
|
||||||
biases += out_col / group_size;
|
scales += out_col / group_size + simd_lid * out_vec_size_g;
|
||||||
x += tid.y * in_vec_size;
|
biases += out_col / group_size + simd_lid * out_vec_size_g;
|
||||||
|
x += tid.y * in_vec_size + simd_lid;
|
||||||
y += tid.y * out_vec_size + out_col;
|
y += tid.y * out_vec_size + out_col;
|
||||||
|
|
||||||
if (out_col >= out_vec_size) {
|
if (out_col >= out_vec_size) {
|
||||||
@ -628,40 +633,61 @@ METAL_FUNC void qvm_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Loop over in_vec in blocks of blocksize
|
// Loop over in_vec in blocks of blocksize
|
||||||
int i = 0;
|
int remaining = in_vec_size % blocksize;
|
||||||
for (; i + blocksize <= in_vec_size; i += blocksize) {
|
if (remaining == 0) {
|
||||||
x_local = x[i + simd_lid];
|
for (int i = 0; i < in_vec_size; i += blocksize) {
|
||||||
scale = scales[(i + simd_lid) * out_vec_size_g];
|
x_local = *x;
|
||||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
scale = *scales;
|
||||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
bias = *biases;
|
||||||
|
w_local = *((device vec_w*)w);
|
||||||
|
|
||||||
qouter<U, pack_factor, bits>(
|
qouter<U, tn * pack_factor, bits>(
|
||||||
|
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
||||||
|
|
||||||
|
x += blocksize;
|
||||||
|
scales += blocksize * out_vec_size_g;
|
||||||
|
biases += blocksize * out_vec_size_g;
|
||||||
|
w += blocksize * out_vec_size_w;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = blocksize; i < in_vec_size; i += blocksize) {
|
||||||
|
x_local = *x;
|
||||||
|
scale = *scales;
|
||||||
|
bias = *biases;
|
||||||
|
w_local = *((device vec_w*)w);
|
||||||
|
|
||||||
|
qouter<U, tn * pack_factor, bits>(
|
||||||
|
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
||||||
|
|
||||||
|
x += blocksize;
|
||||||
|
scales += blocksize * out_vec_size_g;
|
||||||
|
biases += blocksize * out_vec_size_g;
|
||||||
|
w += blocksize * out_vec_size_w;
|
||||||
|
}
|
||||||
|
if (static_cast<int>(simd_lid) < remaining) {
|
||||||
|
x_local = *x;
|
||||||
|
scale = *scales;
|
||||||
|
bias = *biases;
|
||||||
|
w_local = *((device vec_w*)w);
|
||||||
|
} else {
|
||||||
|
x_local = 0;
|
||||||
|
scale = 0;
|
||||||
|
bias = 0;
|
||||||
|
}
|
||||||
|
qouter<U, tn * pack_factor, bits>(
|
||||||
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
||||||
}
|
}
|
||||||
if (static_cast<int>(i + simd_lid) < in_vec_size) {
|
|
||||||
x_local = x[i + simd_lid];
|
|
||||||
scale = scales[(i + simd_lid) * out_vec_size_g];
|
|
||||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
|
||||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
|
||||||
} else {
|
|
||||||
x_local = 0;
|
|
||||||
scale = 0;
|
|
||||||
bias = 0;
|
|
||||||
w_local = 0;
|
|
||||||
}
|
|
||||||
qouter<U, pack_factor, bits>(
|
|
||||||
(thread uint8_t*)&w_local, x_local, scale, bias, result);
|
|
||||||
|
|
||||||
// Accumulate in the simdgroup
|
// Accumulate in the simdgroup
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int k = 0; k < pack_factor; k++) {
|
for (int k = 0; k < tn * pack_factor; k++) {
|
||||||
result[k] = simd_sum(result[k]);
|
result[k] = simd_sum(result[k]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the result
|
// Store the result
|
||||||
if (simd_lid == 0) {
|
if (simd_lid == 0) {
|
||||||
#pragma clang loop unroll(full)
|
#pragma clang loop unroll(full)
|
||||||
for (int k = 0; k < pack_factor; k++) {
|
for (int k = 0; k < tn * pack_factor; k++) {
|
||||||
y[k] = static_cast<T>(result[k]);
|
y[k] = static_cast<T>(result[k]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -137,10 +137,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = d.get_kernel(kname.str());
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
int bo = 8;
|
int bo = 64;
|
||||||
int bd = 32;
|
int bd = 32;
|
||||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||||
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
|
MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
|
||||||
|
|
||||||
compute_encoder.set_input_array(x, 0);
|
compute_encoder.set_input_array(x, 0);
|
||||||
compute_encoder.set_input_array(w, 1);
|
compute_encoder.set_input_array(w, 1);
|
||||||
@ -393,10 +393,10 @@ void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto kernel = d.get_kernel(kname.str());
|
auto kernel = d.get_kernel(kname.str());
|
||||||
compute_encoder->setComputePipelineState(kernel);
|
compute_encoder->setComputePipelineState(kernel);
|
||||||
|
|
||||||
int bo = 8;
|
int bo = 64;
|
||||||
int bd = 32;
|
int bd = 32;
|
||||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
MTL::Size group_dims = MTL::Size(bd, 2, 1);
|
||||||
MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);
|
MTL::Size grid_dims = MTL::Size(O / bo, B, N);
|
||||||
|
|
||||||
compute_encoder.set_input_array(x, 0);
|
compute_encoder.set_input_array(x, 0);
|
||||||
compute_encoder.set_input_array(w, 1);
|
compute_encoder.set_input_array(w, 1);
|
||||||
|
Loading…
Reference in New Issue
Block a user