fix integer overflow in qmm

This commit is contained in:
Awni Hannun 2025-04-30 09:04:12 -07:00
parent 87720a8908
commit 45df803538

View File

@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl(
auto wl = (const device uint8_t*)w;
x += y_row * K;
x += y_row * static_cast<int64_t>(K);
wl += y_col * K_w;
scales += y_col * K_g;
biases += y_col * K_g;
y += y_row * N + y_col;
y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);
@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl(
// Set the block
const int y_row = tid.y * BM;
const int y_col = tid.x * BN;
x += y_row * K;
x += y_row * static_cast<int64_t>(K);
wl += y_col * bytes_per_pack / pack_factor;
scales += y_col / group_size;
biases += y_col / group_size;
y += y_row * N + y_col;
y += y_row * static_cast<int64_t>(N) + y_col;
// Make the x loader and mma operation
const short num_els = min(BM, M - y_row);