mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 19:11:17 +08:00
fix integer overflow in qmm (#2143)
This commit is contained in:
parent
ea890d8710
commit
e496c5a4b4
@ -1008,11 +1008,11 @@ METAL_FUNC void qmm_t_impl(
|
||||
|
||||
auto wl = (const device uint8_t*)w;
|
||||
|
||||
x += y_row * K;
|
||||
x += y_row * static_cast<int64_t>(K);
|
||||
wl += y_col * K_w;
|
||||
scales += y_col * K_g;
|
||||
biases += y_col * K_g;
|
||||
y += y_row * N + y_col;
|
||||
y += y_row * static_cast<int64_t>(N) + y_col;
|
||||
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
@ -1132,11 +1132,11 @@ METAL_FUNC void qmm_n_impl(
|
||||
// Set the block
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
x += y_row * K;
|
||||
x += y_row * static_cast<int64_t>(K);
|
||||
wl += y_col * bytes_per_pack / pack_factor;
|
||||
scales += y_col / group_size;
|
||||
biases += y_col / group_size;
|
||||
y += y_row * N + y_col;
|
||||
y += y_row * static_cast<int64_t>(N) + y_col;
|
||||
|
||||
// Make the x loader and mma operation
|
||||
const short num_els = min(BM, M - y_row);
|
||||
|
Loading…
Reference in New Issue
Block a user