fix integer overflow in qmm (#2143)

This commit is contained in:
Awni Hannun 2025-04-30 09:28:56 -07:00 committed by GitHub
parent ea890d8710
commit e496c5a4b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl(
auto wl = (const device uint8_t*)w;
x += y_row * K;
x += y_row * static_cast<int64_t>(K);
wl += y_col * K_w;
scales += y_col * K_g;
biases += y_col * K_g;
y += y_row * N + y_col;
y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl(
// Set the block
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
x += y_row * static_cast<int64_t>(K);
wl += y_col * bytes_per_pack / pack_factor;
scales += y_col / group_size;
biases += y_col / group_size;
y += y_row * N + y_col;
y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);