Fix OOB access in qmv (#1577)

* fix oob access in qmv

* skip more
This commit is contained in:
Alex Barron 2024-11-08 15:41:30 -08:00 committed by GitHub
parent c1fe1ef081
commit 111fefd5e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -619,21 +619,22 @@ METAL_FUNC void qmv_impl(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0,
values_per_thread);
U sum =
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining);
if (remaining > 0) {
U sum = load_vector_safe<T, U, values_per_thread, bits>(
x, x_thread, remaining);
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl =
(const device uint8_t*)(w + row * in_vec_size_w);
const device T* sl = scales + row * in_vec_size_g;
const device T* bl = biases + row * in_vec_size_g;
U s = sl[0];
U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>(
wl, x_thread, s, b, sum, remaining);
U s = sl[0];
U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>(
wl, x_thread, s, b, sum, remaining);
}
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {