Fix OOB access in qmv (#1577)

* fix oob access in qmv

* skip more
This commit is contained in:
Alex Barron 2024-11-08 15:41:30 -08:00 committed by GitHub
parent c1fe1ef081
commit 111fefd5e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -619,8 +619,9 @@ METAL_FUNC void qmv_impl(
static_cast<int>(in_vec_size - k - simd_lid * values_per_thread), static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
0, 0,
values_per_thread); values_per_thread);
U sum = if (remaining > 0) {
load_vector_safe<T, U, values_per_thread, bits>(x, x_thread, remaining); U sum = load_vector_safe<T, U, values_per_thread, bits>(
x, x_thread, remaining);
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
const device uint8_t* wl = const device uint8_t* wl =
@ -633,7 +634,7 @@ METAL_FUNC void qmv_impl(
result[row] += qdot_safe<U, values_per_thread, bits>( result[row] += qdot_safe<U, values_per_thread, bits>(
wl, x_thread, s, b, sum, remaining); wl, x_thread, s, b, sum, remaining);
} }
}
for (int row = 0; row < results_per_simdgroup; row++) { for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]); result[row] = simd_sum(result[row]);
if (simd_lid == 0) { if (simd_lid == 0) {