From aca758463562db8171f58f27d32fc6567575015d Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 27 Mar 2024 22:18:35 -0700
Subject: [PATCH] Fix OOB read in qmv when non-divisible by blocksize (#917)

---
 mlx/backend/metal/kernels/quantized.metal | 119 ++++++++++++++++++++--
 1 file changed, 109 insertions(+), 10 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index f5c21364d..ce7c05ec0 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -15,14 +15,6 @@
 using namespace metal;
 
 MLX_MTL_CONST int SIMD_SIZE = 32;
 
-template <typename T> struct AccT {
-  typedef T acc_t;
-};
-
-template <> struct AccT<bfloat16_t> {
-  typedef float acc_t;
-};
-
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector(const device T *x, thread U *x_thread) {
@@ -60,6 +52,51 @@ inline U load_vector(const device T *x, thread U *x_thread) {
   return sum;
 }
 
+template <typename T, typename U, int values_per_thread, int bits>
+inline U load_vector_safe(const device T *x, thread U *x_thread, int N) {
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  U sum = 0;
+
+  if (bits == 2) {
+    for (int i = 0; i < N; i += 4) {
+      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
+      x_thread[i] = x[i];
+      x_thread[i+1] = x[i+1] / 4.0f;
+      x_thread[i+2] = x[i+2] / 16.0f;
+      x_thread[i+3] = x[i+3] / 64.0f;
+    }
+    for (int i=N; i<values_per_thread; i++) {
+      x_thread[i] = 0;
+    }
+  }
+
+  else if (bits == 4) {
+    for (int i = 0; i < N; i += 4) {
+      sum += x[i] + x[i+1] + x[i+2] + x[i+3];
+      x_thread[i] = x[i];
+      x_thread[i+1] = x[i+1] / 16.0f;
+      x_thread[i+2] = x[i+2] / 256.0f;
+      x_thread[i+3] = x[i+3] / 4096.0f;
+    }
+    for (int i=N; i<values_per_thread; i++) {
+      x_thread[i] = 0;
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < N; i++) {
+      sum += x[i];
+      x_thread[i] = x[i];
+    }
+    for (int i=N; i<values_per_thread; i++) {
+      x_thread[i] = 0;
+    }
+  }
+
+  return sum;
+}
+
 template <typename U, int values_per_thread, int bits>
 inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum) {
   static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
@@ -96,6 +133,42 @@ inline U qdot(const device uint8_t* w, const thread U *x_thread, U scale, U bias
   return scale * accum + sum * bias;
 }
 
+template <typename U, int values_per_thread, int bits>
+inline U qdot_safe(const device uint8_t* w, const thread U *x_thread, U scale, U bias, U sum, int N) {
+  static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
+
+  U accum = 0;
+
+  if (bits == 2) {
+    for (int i = 0; i < (N / 4); i++) {
+      accum += (
+          x_thread[4*i] * (w[i] & 0x03)
+          + x_thread[4*i+1] * (w[i] & 0x0c)
+          + x_thread[4*i+2] * (w[i] & 0x30)
+          + x_thread[4*i+3] * (w[i] & 0xc0));
+    }
+  }
+
+  else if (bits == 4) {
+    const device uint16_t* ws = (const device uint16_t*)w;
+    for (int i = 0; i < (N / 4); i++) {
+      accum += (
+          x_thread[4*i] * (ws[i] & 0x000f)
+          + x_thread[4*i+1] * (ws[i] & 0x00f0)
+          + x_thread[4*i+2] * (ws[i] & 0x0f00)
+          + x_thread[4*i+3] * (ws[i] & 0xf000));
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < N; i++) {
+      accum += x_thread[i] * w[i];
+    }
+  }
+
+  return scale * accum + sum * bias;
+}
+
 template <typename U, int values_per_thread, int bits>
 inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   static_assert(bits == 2 || bits == 4 || bits == 8, "Template undefined for bits not in {2, 4, 8}");
@@ -236,7 +309,8 @@ template <typename T, const int group_size, const int bits>
     x += tid.z * in_vec_size + simd_lid * values_per_thread;
     y += tid.z * out_vec_size + out_row;
 
-    for (int k = 0; k < in_vec_size; k += block_size) {
+    int k = 0;
+    for (; k < in_vec_size-block_size; k += block_size) {
       U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
       for (int row = 0; out_row + row < out_vec_size; row++) {
@@ -254,6 +328,18 @@ template <typename T, const int group_size, const int bits>
       biases += block_size / group_size;
       x += block_size;
     }
+    const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
+    U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
+
+    for (int row = 0; out_row + row < out_vec_size; row++) {
+      const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+      const device T* sl = scales + row * in_vec_size_g;
+      const device T* bl = biases + row * in_vec_size_g;
+
+      U s = sl[0];
+      U b = bl[0];
+      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+    }
 
     for (int row = 0; out_row + row < out_vec_size; row++) {
       result[row] = simd_sum(result[row]);
@@ -271,7 +357,8 @@ template <typename T, const int group_size, const int bits>
     x += tid.z * in_vec_size + simd_lid * values_per_thread;
     y += tid.z * out_vec_size + used_out_row;
 
-    for (int k = 0; k < in_vec_size; k += block_size) {
+    int k = 0;
+    for (; k < in_vec_size-block_size; k += block_size) {
       U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
 
       for (int row = 0; row < results_per_simdgroup; row++) {
@@ -289,6 +376,18 @@ template <typename T, const int group_size, const int bits>
       biases += block_size / group_size;
      x += block_size;
    }
+    const int remaining = clamp(static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), 0, values_per_thread);
+    U sum = load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
+
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      const device uint8_t* wl = (const device uint8_t *)(w + row * in_vec_size_w);
+      const device T* sl = scales + row * in_vec_size_g;
+      const device T* bl = biases + row * in_vec_size_g;
+
+      U s = sl[0];
+      U b = bl[0];
+      result[row] += qdot_safe<U, values_per_thread, bits>(wl, x_thread, s, b, sum, remaining);
+    }
 
     for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);