mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
Fix OOB read in qmv when non-divisible by blocksize (#917)
This commit is contained in:
parent
d611251502
commit
aca7584635
@ -15,14 +15,6 @@ using namespace metal;
|
||||
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
template <typename T> struct AccT {
|
||||
typedef T acc_t;
|
||||
};
|
||||
|
||||
template <> struct AccT<bfloat16_t> {
|
||||
typedef float acc_t;
|
||||
};
|
||||
|
||||
|
||||
template <typename T, typename U, int values_per_thread, int bits>
|
||||
inline U load_vector(const device T *x, thread U *x_thread) {
|
||||
@ -60,6 +52,51 @@ inline U load_vector(const device T *x, thread U *x_thread) {
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename T, typename U, int values_per_thread, int bits>
|
||||
inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
U sum = 0;
|
||||
|
||||
if (bits == 2) {
|
||||
for (int i = 0; i < N; i += 4) {
|
||||
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
|
||||
x_thread[i] = x[i];
|
||||
x_thread[i+1] = x[i+1] / 4.0f;
|
||||
x_thread[i+2] = x[i+2] / 16.0f;
|
||||
x_thread[i+3] = x[i+3] / 64.0f;
|
||||
}
|
||||
for (int i=N; i<values_per_thread; i++) {
|
||||
x_thread[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
for (int i = 0; i < N; i += 4) {
|
||||
sum += x[i] + x[i+1] + x[i+2] + x[i+3];
|
||||
x_thread[i] = x[i];
|
||||
x_thread[i+1] = x[i+1] / 16.0f;
|
||||
x_thread[i+2] = x[i+2] / 256.0f;
|
||||
x_thread[i+3] = x[i+3] / 4096.0f;
|
||||
}
|
||||
for (int i=N; i<values_per_thread; i++) {
|
||||
x_thread[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 8) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
sum += x[i];
|
||||
x_thread[i] = x[i];
|
||||
}
|
||||
for (int i=N; i<values_per_thread; i++) {
|
||||
x_thread[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
@ -96,6 +133,42 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
|
||||
return scale * accum + sum * bias;
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
|
||||
U accum = 0;
|
||||
|
||||
if (bits == 2) {
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
accum += (
|
||||
x_thread[4*i] * (w[i] & 0x03)
|
||||
+ x_thread[4*i+1] * (w[i] & 0x0c)
|
||||
+ x_thread[4*i+2] * (w[i] & 0x30)
|
||||
+ x_thread[4*i+3] * (w[i] & 0xc0));
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 4) {
|
||||
const device uint16_t* ws = (const device uint16_t*)w;
|
||||
for (int i = 0; i < (N / 4); i++) {
|
||||
accum += (
|
||||
x_thread[4*i] * (ws[i] & 0x000f)
|
||||
+ x_thread[4*i+1] * (ws[i] & 0x00f0)
|
||||
+ x_thread[4*i+2] * (ws[i] & 0x0f00)
|
||||
+ x_thread[4*i+3] * (ws[i] & 0xf000));
|
||||
}
|
||||
}
|
||||
|
||||
else if (bits == 8) {
|
||||
for (int i = 0; i < N; i++) {
|
||||
accum += x_thread[i] * w[i];
|
||||
}
|
||||
}
|
||||
|
||||
return scale * accum + sum * bias;
|
||||
}
|
||||
|
||||
template <typename U, int values_per_thread, int bits>
|
||||
inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
|
||||
static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
|
||||
@ -236,7 +309,8 @@ template <typename T, const int group_size, const int bits>
|
||||
x += tid.z * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.z * out_vec_size + out_row;
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
int k = 0;
|
||||
for (; k < in_vec_size-block_size; k += block_size) {
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
@ -254,6 +328,18 @@ template <typename T, const int group_size, const int bits>
|
||||
biases += block_size / group_size;
|
||||
x += block_size;
|
||||
}
|
||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||
}
|
||||
|
||||
for (int row = 0; out_row + row < out_vec_size; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
@ -271,7 +357,8 @@ template <typename T, const int group_size, const int bits>
|
||||
x += tid.z * in_vec_size + simd_lid * values_per_thread;
|
||||
y += tid.z * out_vec_size + used_out_row;
|
||||
|
||||
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||
int k = 0;
|
||||
for (; k < in_vec_size-block_size; k += block_size) {
|
||||
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
@ -289,6 +376,18 @@ template <typename T, const int group_size, const int bits>
|
||||
biases += block_size / group_size;
|
||||
x += block_size;
|
||||
}
|
||||
const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
|
||||
U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
|
||||
const device T* sl = scales + row * in_vec_size_g;
|
||||
const device T* bl = biases + row * in_vec_size_g;
|
||||
|
||||
U s = sl[0];
|
||||
U b = bl[0];
|
||||
result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
|
||||
}
|
||||
|
||||
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||
result[row] = simd_sum(result[row]);
|
||||
|
Loading…
Reference in New Issue
Block a user