mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:24:49 +08:00
Fix nan and improve speed for qvm (#903)
This commit is contained in:
parent
a3ee03da01
commit
9948eddf11
@ -96,6 +96,38 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
|
||||
return scale * accum + sum * bias;
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
if (bits == 2) {
|
||||
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
result[4*i] += x * (s[0] * (w[i] & 0x03) + bias);
|
||||
result[4*i+1] += x * (s[1] * (w[i] & 0x0c) + bias);
|
||||
result[4*i+2] += x * (s[2] * (w[i] & 0x30) + bias);
|
||||
result[4*i+3] += x * (s[3] * (w[i] & 0xc0) + bias);
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
const thread uint16_t* ws = (const thread uint16_t*)w;
|
||||
U s[4] = {scale, scale / 16.0f, scale / 256.0f, scale / 4096.0f};
|
||||
for (int i = 0; i < (values_per_thread / 4); i++) {
|
||||
result[4*i] += x * (s[0] * (ws[i] & 0x000f) + bias);
|
||||
result[4*i+1] += x * (s[1] * (ws[i] & 0x00f0) + bias);
|
||||
result[4*i+2] += x * (s[2] * (ws[i] & 0x0f00) + bias);
|
||||
result[4*i+3] += x * (s[3] * (ws[i] & 0xf000) + bias);
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 8) {
|
||||
for (int i = 0; i < values_per_thread; i++) {
|
||||
result[i] += x * (scale * w[i] + bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int group_size, int bits, int packs_per_thread>
|
||||
[[kernel]] void qmv_fast(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
@ -268,7 +300,7 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int BM, const int BN, const int group_size, const int bits>
|
||||
template <typename T, const int group_size, const int bits>
|
||||
[[kernel]] void qvm(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
@ -278,39 +310,28 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
const constant int& in_vec_size [[buffer(5)]],
|
||||
const constant int& out_vec_size [[buffer(6)]],
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint lid [[thread_index_in_threadgroup]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
|
||||
static_assert(BM == SIMD_SIZE, "qvm expects BM to be equal to SIMD_SIZE");
|
||||
static_assert(BN == BM, "qvm expects a block size of 32x32");
|
||||
constexpr int num_simdgroups = 8;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int blocksize = SIMD_SIZE;
|
||||
|
||||
(void)lid;
|
||||
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int el_per_int = 32 / bits;
|
||||
constexpr int colgroup = BN * el_per_int;
|
||||
constexpr int groups_per_block = colgroup / group_size;
|
||||
|
||||
typedef typename AccT<T>::acc_t U;
|
||||
threadgroup U scales_block[BM * groups_per_block];
|
||||
threadgroup U biases_block[BM * groups_per_block];
|
||||
threadgroup U x_block[BM];
|
||||
typedef float U;
|
||||
|
||||
thread uint32_t w_local;
|
||||
thread U result[el_per_int] = {0};
|
||||
thread U result[pack_factor] = {0};
|
||||
thread U scale = 1;
|
||||
thread U bias = 0;
|
||||
thread U x_local = 0;
|
||||
|
||||
// Adjust positions
|
||||
const int out_vec_size_w = out_vec_size / el_per_int;
|
||||
const int out_vec_size_w = out_vec_size / pack_factor;
|
||||
const int out_vec_size_g = out_vec_size / group_size;
|
||||
int out_col_start = tid.y * (BN * el_per_int);
|
||||
int out_col = out_col_start + simd_gid * el_per_int;
|
||||
w += out_col / el_per_int;
|
||||
scales += out_col_start / group_size;
|
||||
biases += out_col_start / group_size;
|
||||
int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
|
||||
w += out_col / pack_factor;
|
||||
scales += out_col / group_size;
|
||||
biases += out_col / group_size;
|
||||
x += tid.z * in_vec_size;
|
||||
y += tid.z * out_vec_size + out_col;
|
||||
|
||||
@ -318,53 +339,39 @@ template <typename T, const int BM, const int BN, const int group_size, const in
|
||||
return;
|
||||
}
|
||||
|
||||
// Loop over in_vec in blocks of colgroup
|
||||
for (int i=0; i<in_vec_size; i+=BM) {
|
||||
int offset_lid = simd_lid + i;
|
||||
int offset_gid = simd_gid + i;
|
||||
bool thread_in_bounds = offset_lid < in_vec_size;
|
||||
bool group_in_bounds = offset_gid < in_vec_size;
|
||||
// Loop over in_vec in blocks of blocksize
|
||||
int i = 0;
|
||||
for (; i + blocksize <= in_vec_size; i += blocksize) {
|
||||
x_local = x[i + simd_lid];
|
||||
scale = scales[(i + simd_lid) * out_vec_size_g];
|
||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
||||
|
||||
// Load the vec to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_gid == 0) {
|
||||
x_block[simd_lid] = (thread_in_bounds) ? x[offset_lid] : 0;
|
||||
}
|
||||
|
||||
// Load the scales and biases to shared memory
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (simd_lid < groups_per_block && group_in_bounds) {
|
||||
scales_block[simd_gid * groups_per_block + simd_lid] = scales[offset_gid * out_vec_size_g + simd_lid];
|
||||
biases_block[simd_gid * groups_per_block + simd_lid] = biases[offset_gid * out_vec_size_g + simd_lid];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Load in_vec, scale, bias to registers
|
||||
x_local = x_block[simd_lid];
|
||||
scale = scales_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];
|
||||
|
||||
// Load the matrix elements
|
||||
w_local = (thread_in_bounds) ? w[offset_lid * out_vec_size_w] : 0;
|
||||
|
||||
// Do all the work.
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
result[k] += (scale * static_cast<U>(w_local & bitmask) + bias) * x_local;
|
||||
w_local >>= bits;
|
||||
}
|
||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
||||
}
|
||||
if (static_cast<int>(i + simd_lid) < in_vec_size) {
|
||||
x_local = x[i + simd_lid];
|
||||
scale = scales[(i + simd_lid) * out_vec_size_g];
|
||||
bias = biases[(i + simd_lid) * out_vec_size_g];
|
||||
w_local = w[(i + simd_lid) * out_vec_size_w];
|
||||
} else {
|
||||
x_local = 0;
|
||||
scale = 0;
|
||||
bias = 0;
|
||||
w_local = 0;
|
||||
}
|
||||
qouter<U, pack_factor, bits>((thread uint8_t *)&w_local, x_local, scale, bias, result);
|
||||
|
||||
// Accumulate in the simdgroup
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
for (int k=0; k<pack_factor; k++) {
|
||||
result[k] = simd_sum(result[k]);
|
||||
}
|
||||
|
||||
// Store the result
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_int; k++) {
|
||||
for (int k=0; k<pack_factor; k++) {
|
||||
y[k] = static_cast<T>(result[k]);
|
||||
}
|
||||
}
|
||||
@ -738,7 +745,7 @@ instantiate_qmv_types( 32, 8)
|
||||
|
||||
#define instantiate_qvm(name, itype, group_size, bits) \
|
||||
template [[host_name("qvm_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qvm<itype, 32, 32, group_size, bits>( \
|
||||
[[kernel]] void qvm<itype, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
@ -747,7 +754,6 @@ instantiate_qmv_types( 32, 8)
|
||||
const constant int& in_vec_size [[buffer(5)]], \
|
||||
const constant int& out_vec_size [[buffer(6)]], \
|
||||
uint3 tid [[threadgroup_position_in_grid]], \
|
||||
uint lid [[thread_index_in_threadgroup]], \
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
|
@ -137,7 +137,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto kernel = d.get_kernel(kname.str());
|
||||
compute_encoder->setComputePipelineState(kernel);
|
||||
|
||||
int bo = std::min(32, O);
|
||||
int bo = 8;
|
||||
int bd = 32;
|
||||
MTL::Size group_dims = MTL::Size(bd, bo, 1);
|
||||
MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
|
||||
|
Loading…
Reference in New Issue
Block a user